Skip to content

Commit

Permalink
Implement local caching for WMTS requests (#2316)
Browse files Browse the repository at this point in the history
Add a local filesystem cache for WMTS images, similar to the GoogleTiles cache. This avoids
repeating potentially slow, duplicate requests when dealing with image files from remote sources.
  • Loading branch information
dnowacki-usgs authored Jan 3, 2025
1 parent 113be8e commit e9238b6
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 17 deletions.
83 changes: 68 additions & 15 deletions lib/cartopy/io/ogc_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import collections
import io
import math
import os
from pathlib import Path
from urllib.parse import urlparse
import warnings
import weakref
Expand All @@ -27,6 +29,8 @@
from PIL import Image
import shapely.geometry as sgeom

import cartopy


try:
import owslib.util
Expand Down Expand Up @@ -357,7 +361,7 @@ class WMTSRasterSource(RasterSource):
"""

def __init__(self, wmts, layer_name, gettile_extra_kwargs=None):
def __init__(self, wmts, layer_name, gettile_extra_kwargs=None, cache=False):
"""
Parameters
----------
Expand All @@ -368,6 +372,9 @@ def __init__(self, wmts, layer_name, gettile_extra_kwargs=None):
gettile_extra_kwargs: dict, optional
Extra keywords (e.g. time) to pass through to the
service's gettile method.
cache : bool or str, optional
If True, the default cache directory is used. If False, no cache is
used. If a string, the string is used as the path to the cache.
"""
if WebMapService is None:
Expand Down Expand Up @@ -397,6 +404,18 @@ def __init__(self, wmts, layer_name, gettile_extra_kwargs=None):

self._matrix_set_name_map = {}

# Enable a cache mechanism when cache is equal to True or to a path.
self._default_cache = False
if cache is True:
self._default_cache = True
self.cache_path = Path(cartopy.config["cache_dir"])
elif cache is False:
self.cache_path = None
else:
self.cache_path = Path(cache)
self.cache = set({})
self._load_cache()

def _matrix_set_name(self, target_projection):
key = id(target_projection)
matrix_set_name = self._matrix_set_name_map.get(key)
Expand Down Expand Up @@ -510,6 +529,23 @@ def fetch_raster(self, projection, extent, target_resolution):

return located_images

@property
def _cache_dir(self):
    """
    Directory in which this source stores its cached tiles.

    It is ``self.cache_path`` joined with the name of the instance's class.
    """
    subdirectory = self.__class__.__name__
    return self.cache_path / subdirectory

def _load_cache(self):
    """
    Prepare the on-disk cache directory and index its contents.

    Does nothing when ``self.cache_path`` is None. Otherwise the directory
    given by ``self._cache_dir`` is created if it is missing, and every
    entry found in it is added to ``self.cache``.

    A warning naming the directory is emitted only when this call created
    it and ``self._default_cache`` is true.
    """
    if self.cache_path is None:
        return

    cache_dir = self._cache_dir
    try:
        # Create-and-handle instead of check-then-create: another process
        # may create the directory between an ``exists()`` test and the
        # ``makedirs`` call, which would otherwise raise FileExistsError.
        os.makedirs(cache_dir)
    except FileExistsError:
        # Already present: nothing was created, so there is nothing to
        # announce.
        pass
    else:
        if self._default_cache:
            warnings.warn(
                'Cartopy created the following directory to cache '
                f'WMTSRasterSource tiles: {cache_dir}')
    self.cache = self.cache.union(cache_dir.iterdir())

def _choose_matrix(self, tile_matrices, meters_per_unit, max_pixel_span):
# Get the tile matrices in order of increasing resolution.
tile_matrices = sorted(tile_matrices,
Expand Down Expand Up @@ -642,21 +678,38 @@ def _wmts_images(self, wmts, layer, matrix_set_name, extent,
# Get the tile's Image from the cache if possible.
img_key = (row, col)
img = image_cache.get(img_key)

if img is None:
try:
tile = wmts.gettile(
layer=layer.id,
tilematrixset=matrix_set_name,
tilematrix=str(tile_matrix_id),
row=str(row), column=str(col),
**self.gettile_extra_kwargs)
except owslib.util.ServiceException as exception:
if ('TileOutOfRange' in exception.message and
ignore_out_of_range):
continue
raise exception
img = Image.open(io.BytesIO(tile.read()))
image_cache[img_key] = img
# Try it from disk cache
if self.cache_path is not None:
filename = f"{img_key[0]}_{img_key[1]}.npy"
cached_file = self._cache_dir / filename
else:
filename = None
cached_file = None

if cached_file in self.cache:
img = Image.fromarray(np.load(cached_file, allow_pickle=False))
else:
try:
tile = wmts.gettile(
layer=layer.id,
tilematrixset=matrix_set_name,
tilematrix=str(tile_matrix_id),
row=str(row), column=str(col),
**self.gettile_extra_kwargs)
except owslib.util.ServiceException as exception:
if ('TileOutOfRange' in exception.message and
ignore_out_of_range):
continue
raise exception
img = Image.open(io.BytesIO(tile.read()))
image_cache[img_key] = img
# save image to local cache
if self.cache_path is not None:
np.save(cached_file, img, allow_pickle=False)
self.cache.add(filename)

if big_img is None:
size = (img.size[0] * n_cols, img.size[1] * n_rows)
big_img = Image.new('RGBA', size, (255, 255, 255, 255))
Expand Down
4 changes: 2 additions & 2 deletions lib/cartopy/mpl/geoaxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2224,7 +2224,7 @@ def streamplot(self, x, y, u, v, **kwargs):
sp = super().streamplot(x, y, u, v, **kwargs)
return sp

def add_wmts(self, wmts, layer_name, wmts_kwargs=None, **kwargs):
def add_wmts(self, wmts, layer_name, wmts_kwargs=None, cache=False, **kwargs):
"""
Add the specified WMTS layer to the axes.
Expand All @@ -2249,7 +2249,7 @@ def add_wmts(self, wmts, layer_name, wmts_kwargs=None, **kwargs):
"""
from cartopy.io.ogc_clients import WMTSRasterSource
wmts = WMTSRasterSource(wmts, layer_name,
gettile_extra_kwargs=wmts_kwargs)
gettile_extra_kwargs=wmts_kwargs, cache=cache)
return self.add_raster(wmts, **kwargs)

def add_wms(self, wms, layers, wms_kwargs=None, **kwargs):
Expand Down
79 changes: 79 additions & 0 deletions lib/cartopy/tests/test_img_tiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@
from cartopy import config
import cartopy.crs as ccrs
import cartopy.io.img_tiles as cimgt
import cartopy.io.ogc_clients as ogc


RESOLUTION = (30, 30)

#: Maps Google tile coordinates to native mercator coordinates as defined
#: by https://goo.gl/pgJi.
KNOWN_EXTENTS = {(0, 0, 0): (-20037508.342789244, 20037508.342789244,
Expand Down Expand Up @@ -328,6 +331,82 @@ def test_azuremaps_get_image():
assert extent1 == extent2


def _check_wmts_cache(cache_dir, cache_arg):
    """
    Fetch WMTS tiles with ``cache=cache_arg`` and verify the disk cache.

    Parameters
    ----------
    cache_dir
        The raw parametrized value (``"tmpdir"``, True or False); it selects
        which expectations apply.
    cache_arg
        The value actually passed as ``cache`` to ``WMTSRasterSource``.
    """
    uri = ('https://basemap.nationalmap.gov/arcgis/rest/services/'
           'USGSImageryOnly/MapServer/WMTS/1.0.0/WMTSCapabilities.xml')
    layer_name = 'USGSImageryOnly'
    projection = ccrs.PlateCarree()
    extent = [-10, 10, 40, 60]

    # Fetch tiles and save them in the cache
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter('always')
        source = ogc.WMTSRasterSource(uri, layer_name, cache=cache_arg)
        # The single-element unpacking doubles as a check that exactly one
        # located image is returned.
        located_image, = source.fetch_raster(projection, extent,
                                             RESOLUTION)

    # Do not check the result if the cache is disabled
    if cache_dir is False:
        assert source.cache_path is None
        return

    # Only count the cache-directory warning, so that an unrelated warning
    # raised while talking to the remote service cannot fail this test.
    cache_warnings = [
        entry for entry in w
        if 'Cartopy created the following directory' in str(entry.message)
    ]
    # The warning is expected only when the default cache directory is used.
    assert len(cache_warnings) == (1 if cache_dir is True else 0)

    # Expected cache file names and the MD5 digests of their array data
    expected = [
        ('1_1.npy', '0de548bd47e4579ae0500da6ceeb08e7'),
        ('1_2.npy', '4beebcd3e4408af5accb440d7b4c8933'),
    ]

    # Check the results
    cache_dir_res = source.cache_path / "WMTSRasterSource"
    files = sorted(cache_dir_res.iterdir())
    assert files == [cache_dir_res / name for name, _ in expected]
    assert set(files) == {cache_dir_res / c for c in source.cache}

    # The cache is written with allow_pickle=False, so read it the same way.
    digests = [
        hashlib.md5(np.load(path, allow_pickle=False).data).hexdigest()
        for path in files
    ]
    assert sorted(digests) == sorted(digest for _, digest in expected)

    # Update images in cache (all white)
    for path in files:
        img = np.load(path, allow_pickle=False)
        img.fill(255)
        np.save(path, img, allow_pickle=False)

    wmts_cache = ogc.WMTSRasterSource(uri, layer_name, cache=cache_arg)
    located_image_cache, = wmts_cache.fetch_raster(projection, extent,
                                                   RESOLUTION)

    # Check that the new fetch_raster() call used cached images
    assert wmts_cache.cache == {cache_dir_res / c for c in source.cache}
    assert (np.array(located_image_cache.image) == 255).all()


@pytest.mark.network
@pytest.mark.parametrize('cache_dir', ["tmpdir", True, False])
@pytest.mark.skipif(not ogc._OWSLIB_AVAILABLE, reason='OWSLib is unavailable.')
def test_wmts_cache(cache_dir, tmp_path):
    """
    Check that WMTSRasterSource writes tiles to, and reads them from, disk.

    ``cache_dir`` is parametrized over an explicit directory (``"tmpdir"``),
    the default cache location (True) and no caching (False).
    """
    # "tmpdir" stands for an explicit path; True/False are passed through.
    cache_arg = str(tmp_path) if cache_dir == "tmpdir" else cache_dir

    if cache_dir is not True:
        _check_wmts_cache(cache_dir, cache_arg)
        return

    # Point the default cache at tmp_path for this test only, and restore
    # the global setting afterwards so other tests are unaffected.
    previous_cache_dir = config["cache_dir"]
    config["cache_dir"] = str(tmp_path)
    try:
        _check_wmts_cache(cache_dir, cache_arg)
    finally:
        config["cache_dir"] = previous_cache_dir


@pytest.mark.network
@pytest.mark.parametrize('cache_dir', ["tmpdir", True, False])
def test_cache(cache_dir, tmp_path):
Expand Down

0 comments on commit e9238b6

Please sign in to comment.