Skip to content

Commit

Permalink
Reproject method for ImageData (#782)
Browse files Browse the repository at this point in the history
* Add reproject method to ImageData class with tests

* Add reproject method for ImageData objects in CHANGES.md

* Remove unused rasterio import and clean up test cases for reproject method

* lint

* fix and update tests

* relaxe rasterio version limit

* update changelog

---------

Co-authored-by: vincentsarago <[email protected]>
  • Loading branch information
emmanuelmathot and vincentsarago authored Jan 28, 2025
1 parent 3fbeaf3 commit 8a47991
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 4 deletions.
11 changes: 11 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
# Unreleased (TBD)

* update rasterio dependency to `>=1.4.0`

* add `reproject` method for `ImageData` objects (author @emmanuelmathot, https://github.com/cogeotiff/rio-tiler/pull/782)

```python
from rio_tiler.models import ImageData

img = ImageData(numpy.zeros((3, 256, 256), crs=CRS.from_epsg(4326), dtype="uint8"))
img_3857 = img.reproject("epsg:3857")
```

* add `indexes` parameter for `XarrayReader` methods. As for Rasterio, the indexes values start at `1`.

```python
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ dependencies = [
"morecantile>=5.0,<7.0",
"pydantic~=2.0",
"pystac>=0.5.4",
"rasterio>=1.3.0",
"rasterio>=1.4.0",
"color-operations",
"typing-extensions",
"importlib_resources>=1.1.0; python_version < '3.9'",
Expand Down
48 changes: 45 additions & 3 deletions rio_tiler/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
from rasterio.coords import BoundingBox
from rasterio.crs import CRS
from rasterio.dtypes import dtype_ranges
from rasterio.enums import ColorInterp
from rasterio.enums import ColorInterp, Resampling
from rasterio.errors import NotGeoreferencedWarning
from rasterio.features import rasterize
from rasterio.io import MemoryFile
from rasterio.plot import reshape_as_image
from rasterio.transform import from_bounds
from rasterio.warp import transform_geom
from rasterio.transform import array_bounds, from_bounds
from rasterio.warp import calculate_default_transform, reproject, transform_geom
from typing_extensions import Self

from rio_tiler.colormap import apply_cmap
Expand All @@ -34,6 +34,7 @@
IntervalTuple,
NumType,
RIOResampling,
WarpResampling,
)
from rio_tiler.utils import (
_validate_shape_input,
Expand Down Expand Up @@ -786,3 +787,44 @@ def get_coverage_array(
).astype("float32")

return cover_array.sum(-1).sum(1) / (cover_scale**2)

def reproject(
self,
dst_crs: CRS,
resolution: Optional[Tuple[float, float]] = None,
reproject_method: WarpResampling = "nearest",
) -> "ImageData":
"""Reproject data and mask."""
dst_transform, w, h = calculate_default_transform(
self.crs,
dst_crs,
self.width,
self.height,
*self.bounds,
resolution=resolution,
)

destination = numpy.ma.masked_array(
numpy.zeros((self.count, h, w), dtype=self.array.dtype),
)
destination, _ = reproject(
self.array,
destination,
src_transform=self.transform,
src_crs=self.crs,
dst_transform=dst_transform,
dst_crs=dst_crs,
resampling=Resampling[reproject_method],
)

bounds = array_bounds(h, w, dst_transform)

return ImageData(
destination,
assets=self.assets,
crs=dst_crs,
bounds=bounds,
band_names=self.band_names,
metadata=self.metadata,
dataset_statistics=self.dataset_statistics,
)
68 changes: 68 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,3 +466,71 @@ def test_image_encoding_error():
"""Test ImageData error when using bad data array shape."""
with pytest.raises(InvalidFormat):
ImageData(numpy.zeros((5, 256, 256), dtype="uint8")).render(img_format="PNG")


def test_image_reproject():
"""Test basic reproject functionality."""
data = numpy.zeros((1, 256, 256), dtype="uint8")
data[0:256, 0:256] = 1
mask = numpy.zeros((1, 256, 256), dtype="bool")
mask[0:100, 0:100] = True

# Create test image with WGS84 CRS
src_crs = CRS.from_epsg(4326)
img = ImageData(
numpy.ma.MaskedArray(data=data, mask=mask),
crs=src_crs,
bounds=(-95, 43, -92, 45),
metadata={"test": "value"},
band_names=["band1"],
)

# Test re-projection to Web Mercator
dst_crs = CRS.from_epsg(3857)

reprojected = img.reproject(dst_crs)
assert reprojected.crs == dst_crs
assert reprojected.count == 1
assert reprojected.width != 256
assert reprojected.height != 256
assert reprojected.array[0, 0, 0].data == 0
assert reprojected.array.data[0, -10, -10] == 1
assert reprojected.array.mask.shape[0] == 1
assert reprojected.array.mask[0, 0, 0]
assert not reprojected.array.mask[0, -10, -10]
assert reprojected.metadata == img.metadata
assert reprojected.band_names == img.band_names

# Test no re-projection when CRS is the same
same_crs = img.reproject(src_crs)
assert same_crs.crs == src_crs
assert same_crs.transform == img.transform
numpy.testing.assert_array_equal(same_crs.array, img.array)

# Test with different resampling method
reprojected_bilinear = img.reproject(dst_crs, reproject_method="bilinear")
with numpy.testing.assert_raises(AssertionError):
numpy.testing.assert_array_equal(reprojected_bilinear.array, img.array)

# With MultiBands
data = numpy.zeros((3, 256, 256), dtype="uint8")
data[:, 0:256, 0:256] = 1
mask = numpy.zeros((3, 256, 256), dtype="bool")
mask[:, 0:100, 0:100] = True

img = ImageData(
numpy.ma.MaskedArray(data=data, mask=mask),
crs=src_crs,
bounds=(-95, 43, -92, 45),
)

reprojected = img.reproject(dst_crs)
assert reprojected.crs == dst_crs
assert reprojected.count == 3
assert reprojected.width != 256
assert reprojected.height != 256
assert reprojected.array.data[:, 0, 0].tolist() == [0, 0, 0]
assert reprojected.array.data[:, -10, -10].tolist() == [1, 1, 1]
assert reprojected.array.mask.shape[0] == 3
assert reprojected.array.mask[:, 0, 0].tolist() == [True, True, True]
assert reprojected.array.mask[:, -10, -10].tolist() == [False, False, False]

0 comments on commit 8a47991

Please sign in to comment.