Skip to content

Commit

Permalink
Merge pull request #2925 from mraspaud/fix-sar-type
Browse files Browse the repository at this point in the history
Fix types to allow float32 computations for SAR-C
  • Loading branch information
mraspaud authored Oct 17, 2024
2 parents 56981ff + 5634073 commit 9ebda32
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 31 deletions.
57 changes: 32 additions & 25 deletions satpy/readers/sar_c_safe.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def _dictify(r):
return int(r.text)
except ValueError:
try:
return float(r.text)
return np.float32(r.text)
except ValueError:
return r.text
for x in r.findall("./*"):
Expand Down Expand Up @@ -186,7 +186,7 @@ def get_dataset(self, key, info, chunks=None):

def get_calibration_constant(self):
"""Load the calibration constant."""
return float(self.root.find(".//absoluteCalibrationConstant").text)
return np.float32(self.root.find(".//absoluteCalibrationConstant").text)

def _get_calibration_uncached(self, calibration, chunks=None):
"""Get the calibration array."""
Expand Down Expand Up @@ -341,7 +341,7 @@ def _get_array_pieces_for_current_line(self, current_line):
current_blocks = self._find_blocks_covering_line(current_line)
current_blocks.sort(key=(lambda x: x.coords["x"][0]))
next_line = self._get_next_start_line(current_blocks, current_line)
current_y = np.arange(current_line, next_line)
current_y = np.arange(current_line, next_line, dtype=np.uint16)
pieces = [arr.sel(y=current_y) for arr in current_blocks]
return pieces

Expand Down Expand Up @@ -389,7 +389,7 @@ def _get_padded_dask_pieces(self, pieces, chunks):
@staticmethod
def _fill_dask_pieces(dask_pieces, shape, chunks):
if shape[1] > 0:
new_piece = da.full(shape, np.nan, chunks=chunks)
new_piece = da.full(shape, np.nan, chunks=chunks, dtype=np.float32)
dask_pieces.append(new_piece)


Expand Down Expand Up @@ -425,11 +425,10 @@ def expand(self, chunks):
# corr = 1.5
data = self.lut * corr

x_coord = np.arange(self.first_pixel, self.last_pixel + 1)
y_coord = np.arange(self.first_line, self.last_line + 1)

new_arr = (da.ones((len(y_coord), len(x_coord)), chunks=chunks) *
np.interp(y_coord, self.lines, data)[:, np.newaxis])
x_coord = np.arange(self.first_pixel, self.last_pixel + 1, dtype=np.uint16)
y_coord = np.arange(self.first_line, self.last_line + 1, dtype=np.uint16)
new_arr = (da.ones((len(y_coord), len(x_coord)), dtype=np.float32, chunks=chunks) *
np.interp(y_coord, self.lines, data)[:, np.newaxis].astype(np.float32))
new_arr = xr.DataArray(new_arr,
dims=["y", "x"],
coords={"x": x_coord,
Expand All @@ -438,29 +437,29 @@ def expand(self, chunks):

@property
def first_pixel(self):
return int(self.element.find("firstRangeSample").text)
return np.uint16(self.element.find("firstRangeSample").text)

@property
def last_pixel(self):
return int(self.element.find("lastRangeSample").text)
return np.uint16(self.element.find("lastRangeSample").text)

@property
def first_line(self):
return int(self.element.find("firstAzimuthLine").text)
return np.uint16(self.element.find("firstAzimuthLine").text)

@property
def last_line(self):
return int(self.element.find("lastAzimuthLine").text)
return np.uint16(self.element.find("lastAzimuthLine").text)

@property
def lines(self):
lines = self.element.find("line").text.split()
return np.array(lines).astype(int)
return np.array(lines).astype(np.uint16)

@property
def lut(self):
lut = self.element.find("noiseAzimuthLut").text.split()
return np.array(lut).astype(float)
return np.array(lut, dtype=np.float32)


class XMLArray:
Expand All @@ -487,7 +486,7 @@ def _read_xml_array(self):
new_x = elt.find("pixel").text.split()
y += [int(elt.find("line").text)] * len(new_x)
x += [int(val) for val in new_x]
data += [float(val)
data += [np.float32(val)
for val in elt.find(self.element_tag).text.split()]

return np.asarray(data), (x, y)
Expand Down Expand Up @@ -519,17 +518,17 @@ def interpolate_xarray_linear(xpoints, ypoints, values, shape, chunks=CHUNK_SIZE
else:
vchunks, hchunks = chunks, chunks

points = _ndim_coords_from_arrays(np.vstack((np.asarray(ypoints),
np.asarray(xpoints))).T)
points = _ndim_coords_from_arrays(np.vstack((np.asarray(ypoints, dtype=np.uint16),
np.asarray(xpoints, dtype=np.uint16))).T)

interpolator = LinearNDInterpolator(points, values)

grid_x, grid_y = da.meshgrid(da.arange(shape[1], chunks=hchunks),
da.arange(shape[0], chunks=vchunks))
grid_x, grid_y = da.meshgrid(da.arange(shape[1], chunks=hchunks, dtype=np.uint16),
da.arange(shape[0], chunks=vchunks, dtype=np.uint16))

# workaround for non-thread-safe first call of the interpolator:
interpolator((0, 0))
res = da.map_blocks(intp, grid_x, grid_y, interpolator=interpolator)
res = da.map_blocks(intp, grid_x, grid_y, interpolator=interpolator).astype(values.dtype)

return DataArray(res, dims=("y", "x"))

Expand Down Expand Up @@ -617,7 +616,7 @@ def _calibrate_and_denoise(self, data, key):
def _get_digital_number(self, data):
"""Get the digital numbers (uncalibrated data)."""
data = data.where(data > 0)
data = data.astype(np.float64)
data = data.astype(np.float32)
dn = data * data
return dn

Expand Down Expand Up @@ -675,8 +674,8 @@ def get_gcps(self):
for feature in gcps["features"]]
gcp_array = np.array(gcp_list)

ypoints = np.unique(gcp_array[:, 0])
xpoints = np.unique(gcp_array[:, 1])
ypoints = np.unique(gcp_array[:, 0]).astype(np.uint16)
xpoints = np.unique(gcp_array[:, 1]).astype(np.uint16)

gcp_lons = gcp_array[:, 2].reshape(ypoints.shape[0], xpoints.shape[0])
gcp_lats = gcp_array[:, 3].reshape(ypoints.shape[0], xpoints.shape[0])
Expand All @@ -686,6 +685,13 @@ def get_gcps(self):

return (xpoints, ypoints), (gcp_lons, gcp_lats, gcp_alts), (rio_gcps, crs)

def get_bounding_box(self):
    """Get the bounding box for the data coverage.

    The box is the outer ring of the 2-D ground control point (GCP) grids
    returned by ``self.get_gcps()``, walked in a single direction: first row
    (forward), last column (forward), last row (reversed), first column
    (reversed). Each edge contributes every point except its final one, which
    is the first point of the following edge, so every perimeter point appears
    exactly once and the ring is not explicitly closed.

    Returns:
        tuple: ``(lons, lats)``, two plain Python lists of equal length with
        the longitudes and latitudes of the perimeter points.
    """
    _, (gcp_lons, gcp_lats, _), _ = self.get_gcps()

    def _perimeter(grid):
        # The reversed edges must stop *before index 0* (``:0:-1``): index 0
        # belongs to the next edge, but index 1 is still part of this one.
        # Using ``:1:-1`` here would silently drop one point on each of the
        # two reversed edges.
        return np.hstack((grid[0, :-1],
                          grid[:-1, -1],
                          grid[-1, :0:-1],
                          grid[:0:-1, 0]))

    return _perimeter(gcp_lons).tolist(), _perimeter(gcp_lats).tolist()

@property
def start_time(self):
"""Get the start time."""
Expand Down Expand Up @@ -733,7 +739,8 @@ def load(self, dataset_keys, **kwargs):
gcps = get_gcps_from_array(val)
from pyresample.future.geometry import SwathDefinition
val.attrs["area"] = SwathDefinition(lonlats["longitude"], lonlats["latitude"],
attrs=dict(gcps=gcps))
attrs=dict(gcps=gcps,
bounding_box=handler.get_bounding_box()))
datasets[key] = val
continue
return datasets
Expand Down
35 changes: 29 additions & 6 deletions satpy/tests/reader_tests/test_sar_c_safe.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,22 +215,28 @@ def test_read_calibrated_natural(self, measurement_filehandler):
calibration = Calibration.sigma_nought
xarr = measurement_filehandler.get_dataset(DataQuery(name="measurement", polarization="vv",
calibration=calibration, quantity="natural"), info=dict())
expected = np.array([[np.nan, 0.02707529], [2.55858416, 3.27611055]])
expected = np.array([[np.nan, 0.02707529], [2.55858416, 3.27611055]], dtype=np.float32)
np.testing.assert_allclose(xarr.values[:2, :2], expected, rtol=2e-7)
assert xarr.dtype == np.float32
assert xarr.compute().dtype == np.float32

def test_read_calibrated_dB(self, measurement_filehandler):
"""Test the calibration routines."""
calibration = Calibration.sigma_nought
xarr = measurement_filehandler.get_dataset(DataQuery(name="measurement", polarization="vv",
calibration=calibration, quantity="dB"), info=dict())
expected = np.array([[np.nan, -15.674268], [4.079997, 5.153585]])
np.testing.assert_allclose(xarr.values[:2, :2], expected)
expected = np.array([[np.nan, -15.674268], [4.079997, 5.153585]], dtype=np.float32)
np.testing.assert_allclose(xarr.values[:2, :2], expected, rtol=1e-6)
assert xarr.dtype == np.float32
assert xarr.compute().dtype == np.float32

def test_read_lon_lats(self, measurement_filehandler):
"""Test reading lons and lats."""
query = DataQuery(name="longitude", polarization="vv")
xarr = measurement_filehandler.get_dataset(query, info=dict())
np.testing.assert_allclose(xarr.values, expected_longitudes)
assert xarr.dtype == np.float64
assert xarr.compute().dtype == np.float64


annotation_xml = b"""<?xml version="1.0" encoding="UTF-8"?>
Expand Down Expand Up @@ -702,6 +708,8 @@ def test_get_noise_dataset(self, noise_filehandler):
query = DataQuery(name="noise", polarization="vv")
res = noise_filehandler.get_dataset(query, {})
np.testing.assert_allclose(res, self.expected_azimuth_noise * self.expected_range_noise)
assert res.dtype == np.float32
assert res.compute().dtype == np.float32

def test_get_noise_dataset_has_right_chunk_size(self, noise_filehandler):
"""Test using get_dataset for the noise has right chunk size in result."""
Expand All @@ -724,31 +732,40 @@ def test_dn_calibration_array(self, calibration_filehandler):
expected_dn = np.ones((10, 10)) * 1087
res = calibration_filehandler.get_calibration(Calibration.dn, chunks=5)
np.testing.assert_allclose(res, expected_dn)
assert res.dtype == np.float32
assert res.compute().dtype == np.float32

def test_beta_calibration_array(self, calibration_filehandler):
"""Test reading the beta calibration array."""
expected_beta = np.ones((10, 10)) * 1087
res = calibration_filehandler.get_calibration(Calibration.beta_nought, chunks=5)
np.testing.assert_allclose(res, expected_beta)
assert res.dtype == np.float32
assert res.compute().dtype == np.float32

def test_sigma_calibration_array(self, calibration_filehandler):
"""Test reading the sigma calibration array."""
expected_sigma = np.array([[1894.274, 1841.4335, 1788.593, 1554.4165, 1320.24, 1299.104,
1277.968, 1277.968, 1277.968, 1277.968]]) * np.ones((10, 1))
res = calibration_filehandler.get_calibration(Calibration.sigma_nought, chunks=5)
np.testing.assert_allclose(res, expected_sigma)

assert res.dtype == np.float32
assert res.compute().dtype == np.float32

def test_gamma_calibration_array(self, calibration_filehandler):
"""Test reading the gamma calibration array."""
res = calibration_filehandler.get_calibration(Calibration.gamma, chunks=5)
np.testing.assert_allclose(res, self.expected_gamma)
assert res.dtype == np.float32
assert res.compute().dtype == np.float32

def test_get_calibration_dataset(self, calibration_filehandler):
"""Test using get_dataset for the calibration."""
query = DataQuery(name="gamma", polarization="vv")
res = calibration_filehandler.get_dataset(query, {})
np.testing.assert_allclose(res, self.expected_gamma)
assert res.dtype == np.float32
assert res.compute().dtype == np.float32

def test_get_calibration_dataset_has_right_chunk_size(self, calibration_filehandler):
"""Test using get_dataset for the calibration yields array with right chunksize."""
Expand All @@ -762,13 +779,16 @@ def test_get_calibration_constant(self, calibration_filehandler):
query = DataQuery(name="calibration_constant", polarization="vv")
res = calibration_filehandler.get_dataset(query, {})
assert res == 1
assert type(res) is np.float32


def test_incidence_angle(annotation_filehandler):
"""Test reading the incidence angle in an annotation file."""
query = DataQuery(name="incidence_angle", polarization="vv")
res = annotation_filehandler.get_dataset(query, {})
np.testing.assert_allclose(res, 19.18318046)
assert res.dtype == np.float32
assert res.compute().dtype == np.float32


def test_reading_from_reader(measurement_file, calibration_file, noise_file, annotation_file):
Expand All @@ -787,7 +807,9 @@ def test_reading_from_reader(measurement_file, calibration_file, noise_file, ann
array = dataset_dict["measurement"]
np.testing.assert_allclose(array.attrs["area"].lons, expected_longitudes)
expected_db = np.array([[np.nan, -15.674268], [4.079997, 5.153585]])
np.testing.assert_allclose(array.values[:2, :2], expected_db)
np.testing.assert_allclose(array.values[:2, :2], expected_db, rtol=1e-6)
assert array.dtype == np.float32
assert array.compute().dtype == np.float32


def test_filename_filtering_from_reader(measurement_file, calibration_file, noise_file, annotation_file, tmp_path):
Expand All @@ -814,7 +836,7 @@ def test_filename_filtering_from_reader(measurement_file, calibration_file, nois
pytest.fail(str(err))


def test_swath_def_contains_gcps(measurement_file, calibration_file, noise_file, annotation_file):
def test_swath_def_contains_gcps_and_bounding_box(measurement_file, calibration_file, noise_file, annotation_file):
"""Test reading using the reader defined in the config."""
with open(Path(PACKAGE_CONFIG_PATH) / "readers" / "sar-c_safe.yaml") as fd:
config = yaml.load(fd, Loader=yaml.UnsafeLoader)
Expand All @@ -829,3 +851,4 @@ def test_swath_def_contains_gcps(measurement_file, calibration_file, noise_file,
dataset_dict = reader.load([query])
array = dataset_dict["measurement"]
assert array.attrs["area"].attrs["gcps"] is not None
assert array.attrs["area"].attrs["bounding_box"] is not None

0 comments on commit 9ebda32

Please sign in to comment.