diff --git a/lib/cartopy/io/ogc_clients.py b/lib/cartopy/io/ogc_clients.py index 285f8222d..64016c669 100644 --- a/lib/cartopy/io/ogc_clients.py +++ b/lib/cartopy/io/ogc_clients.py @@ -18,6 +18,8 @@ import collections import io import math +import os +from pathlib import Path from urllib.parse import urlparse import warnings import weakref @@ -27,6 +29,8 @@ from PIL import Image import shapely.geometry as sgeom +import cartopy + try: import owslib.util @@ -357,7 +361,7 @@ class WMTSRasterSource(RasterSource): """ - def __init__(self, wmts, layer_name, gettile_extra_kwargs=None): + def __init__(self, wmts, layer_name, gettile_extra_kwargs=None, cache=False): """ Parameters ---------- @@ -368,6 +372,9 @@ def __init__(self, wmts, layer_name, gettile_extra_kwargs=None): gettile_extra_kwargs: dict, optional Extra keywords (e.g. time) to pass through to the service's gettile method. + cache : bool or str, optional + If True, the default cache directory is used. If False, no cache is + used. If a string, the string is used as the path to the cache. """ if WebMapService is None: @@ -397,6 +404,18 @@ def __init__(self, wmts, layer_name, gettile_extra_kwargs=None): self._matrix_set_name_map = {} + # Enable a cache mechanism when cache is equal to True or to a path. + self._default_cache = False + if cache is True: + self._default_cache = True + self.cache_path = Path(cartopy.config["cache_dir"]) + elif cache is False: + self.cache_path = None + else: + self.cache_path = Path(cache) + self.cache = set({}) + self._load_cache() + def _matrix_set_name(self, target_projection): key = id(target_projection) matrix_set_name = self._matrix_set_name_map.get(key) @@ -510,6 +529,23 @@ def fetch_raster(self, projection, extent, target_resolution): return located_images + @property + def _cache_dir(self): + """Return the name of the cache directory""" + return self.cache_path / self.__class__.__name__ + + def _load_cache(self): + """Load the cache""" + if self.cache_path is not None: + cache_dir = self._cache_dir + if not cache_dir.exists(): + os.makedirs(cache_dir) + if self._default_cache: + warnings.warn( + 'Cartopy created the following directory to cache ' + f'WMTSRasterSource tiles: {cache_dir}') + self.cache = self.cache.union(set(cache_dir.iterdir())) + def _choose_matrix(self, tile_matrices, meters_per_unit, max_pixel_span): # Get the tile matrices in order of increasing resolution. tile_matrices = sorted(tile_matrices, @@ -642,21 +678,38 @@ def _wmts_images(self, wmts, layer, matrix_set_name, extent, # Get the tile's Image from the cache if possible. img_key = (row, col) img = image_cache.get(img_key) + if img is None: - try: - tile = wmts.gettile( - layer=layer.id, - tilematrixset=matrix_set_name, - tilematrix=str(tile_matrix_id), - row=str(row), column=str(col), - **self.gettile_extra_kwargs) - except owslib.util.ServiceException as exception: - if ('TileOutOfRange' in exception.message and - ignore_out_of_range): - continue - raise exception - img = Image.open(io.BytesIO(tile.read())) - image_cache[img_key] = img + # Try it from disk cache + if self.cache_path is not None: + filename = f"{img_key[0]}_{img_key[1]}.npy" + cached_file = self._cache_dir / filename + else: + filename = None + cached_file = None + + if cached_file in self.cache: + img = Image.fromarray(np.load(cached_file, allow_pickle=False)) + else: + try: + tile = wmts.gettile( + layer=layer.id, + tilematrixset=matrix_set_name, + tilematrix=str(tile_matrix_id), + row=str(row), column=str(col), + **self.gettile_extra_kwargs) + except owslib.util.ServiceException as exception: + if ('TileOutOfRange' in exception.message and + ignore_out_of_range): + continue + raise exception + img = Image.open(io.BytesIO(tile.read())) + image_cache[img_key] = img + # save image to local cache + if self.cache_path is not None: + np.save(cached_file, img, allow_pickle=False) + self.cache.add(filename) + if big_img is None: size = (img.size[0] * n_cols, img.size[1] * n_rows) big_img = Image.new('RGBA', size, (255, 255, 255, 255)) diff --git a/lib/cartopy/mpl/geoaxes.py b/lib/cartopy/mpl/geoaxes.py index 85028ad2d..b58e594cf 100644 --- a/lib/cartopy/mpl/geoaxes.py +++ b/lib/cartopy/mpl/geoaxes.py @@ -2224,7 +2224,7 @@ def streamplot(self, x, y, u, v, **kwargs): sp = super().streamplot(x, y, u, v, **kwargs) return sp - def add_wmts(self, wmts, layer_name, wmts_kwargs=None, **kwargs): + def add_wmts(self, wmts, layer_name, wmts_kwargs=None, cache=False, **kwargs): """ Add the specified WMTS layer to the axes. @@ -2249,7 +2249,7 @@ def add_wmts(self, wmts, layer_name, wmts_kwargs=None, **kwargs): """ from cartopy.io.ogc_clients import WMTSRasterSource wmts = WMTSRasterSource(wmts, layer_name, - gettile_extra_kwargs=wmts_kwargs) + gettile_extra_kwargs=wmts_kwargs, cache=cache) return self.add_raster(wmts, **kwargs) def add_wms(self, wms, layers, wms_kwargs=None, **kwargs): diff --git a/lib/cartopy/tests/test_img_tiles.py b/lib/cartopy/tests/test_img_tiles.py index bc4354cd0..d92224ee7 100644 --- a/lib/cartopy/tests/test_img_tiles.py +++ b/lib/cartopy/tests/test_img_tiles.py @@ -16,8 +16,11 @@ from cartopy import config import cartopy.crs as ccrs import cartopy.io.img_tiles as cimgt +import cartopy.io.ogc_clients as ogc +RESOLUTION = (30, 30) + #: Maps Google tile coordinates to native mercator coordinates as defined #: by https://goo.gl/pgJi. KNOWN_EXTENTS = {(0, 0, 0): (-20037508.342789244, 20037508.342789244, @@ -328,6 +331,82 @@ def test_azuremaps_get_image(): assert extent1 == extent2 +@pytest.mark.network +@pytest.mark.parametrize('cache_dir', ["tmpdir", True, False]) +@pytest.mark.skipif(not ogc._OWSLIB_AVAILABLE, reason='OWSLib is unavailable.') +def test_wmts_cache(cache_dir, tmp_path): + if cache_dir == "tmpdir": + tmpdir_str = str(tmp_path) + else: + tmpdir_str = cache_dir + + if cache_dir is True: + config["cache_dir"] = str(tmp_path) + + # URI = 'https://map1c.vis.earthdata.nasa.gov/wmts-geo/wmts.cgi' + # layer_name = 'VIIRS_CityLights_2012' + URI = 'https://basemap.nationalmap.gov/arcgis/rest/services/USGSImageryOnly/MapServer/WMTS/1.0.0/WMTSCapabilities.xml' + layer_name='USGSImageryOnly' + projection = ccrs.PlateCarree() + + # Fetch tiles and save them in the cache + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + source = ogc.WMTSRasterSource(URI, layer_name, cache=tmpdir_str) + extent = [-10, 10, 40, 60] + located_image, = source.fetch_raster(projection, extent, + RESOLUTION) + + # Do not check the result if the cache is disabled + if cache_dir is False: + assert source.cache_path is None + return + + # Check that the warning is properly raised (only when cache is True) + if cache_dir is True: + assert len(w) == 1 + else: + assert len(w) == 0 + + # Define expected results + x_y_f_h = [ + (1, 1, '1_1.npy', '0de548bd47e4579ae0500da6ceeb08e7'), + (1, 2, '1_2.npy', '4beebcd3e4408af5accb440d7b4c8933'), + ] + + # Check the results + cache_dir_res = source.cache_path / "WMTSRasterSource" + files = list(cache_dir_res.iterdir()) + hashes = { + f: + hashlib.md5( + np.load(cache_dir_res / f, allow_pickle=True).data + ).hexdigest() + for f in files + } + assert sorted(files) == [cache_dir_res / f for x, y, f, h in x_y_f_h] + assert set(files) == set([cache_dir_res / c for c in source.cache]) + + assert sorted(hashes.values()) == sorted( + h for x, y, f, h in x_y_f_h + ) + + # Update images in cache (all white) + for f in files: + filename = cache_dir_res / f + img = np.load(filename, allow_pickle=True) + img.fill(255) + np.save(filename, img, allow_pickle=True) + + wmts_cache = ogc.WMTSRasterSource(URI, layer_name, cache=tmpdir_str) + located_image_cache, = wmts_cache.fetch_raster(projection, extent, + RESOLUTION) + + # Check that the new fetch_raster() call used cached images + assert wmts_cache.cache == set([cache_dir_res / c for c in source.cache]) + assert (np.array(located_image_cache.image) == 255).all() + + @pytest.mark.network @pytest.mark.parametrize('cache_dir', ["tmpdir", True, False]) def test_cache(cache_dir, tmp_path):