Skip to content

Commit

Permalink
Refactor to move vis param calculation into EELeafletTileLayer
Browse files Browse the repository at this point in the history
  • Loading branch information
aazuspan committed Oct 31, 2023
1 parent 2f7ee1e commit f2b3a6c
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 47 deletions.
78 changes: 78 additions & 0 deletions geemap/ee_tile_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import ee
import folium
import ipyleaflet
from functools import lru_cache

from . import common

Expand Down Expand Up @@ -139,6 +140,7 @@ def __init__(
shown (bool, optional): A flag indicating whether the layer should be on by default. Defaults to True.
opacity (float, optional): The layer's opacity represented as a number between 0 and 1. Defaults to 1.
"""
self._ee_object = ee_object
self.url_format = _get_tile_url_format(
ee_object, _validate_vis_params(vis_params)
)
Expand All @@ -151,3 +153,79 @@ def __init__(
max_zoom=24,
**kwargs,
)

@lru_cache()
def _calculate_vis_stats(self, *, bounds, bands):
"""Calculate stats used for visualization parameters.
Stats are calculated consistently with the Code Editor visualization parameters,
and are cached to avoid recomputing for the same bounds and bands.
Args:
bounds (ee.Geometry|ee.Feature|ee.FeatureCollection): The bounds to sample.
bands (tuple): The bands to sample.
Returns:
tuple: The minimum, maximum, standard deviation, and mean values across the
specified bands.
"""
stat_reducer = (ee.Reducer.minMax()
.combine(ee.Reducer.mean().unweighted(), sharedInputs=True)
.combine(ee.Reducer.stdDev(), sharedInputs=True))

stats = self._ee_object.select(bands).reduceRegion(
reducer=stat_reducer,
geometry=bounds,
bestEffort=True,
maxPixels=10_000,
crs="SR-ORG:6627",
scale=1,
).getInfo()

mins, maxs, stds, means = [
{v for k, v in stats.items() if k.endswith(stat) and v is not None}
for stat in ('_min', '_max', '_stdDev', '_mean')
]
if any(len(vals) == 0 for vals in (mins, maxs, stds, means)):
raise ValueError('No unmasked pixels were sampled.')

min_val = min(mins)
max_val = max(maxs)
std_dev = sum(stds) / len(stds)
mean = sum(means) / len(means)

return (min_val, max_val, std_dev, mean)

def calculate_vis_minmax(self, *, bounds, bands=None, percent=None, sigma=None):
"""Calculate the min and max clip values for visualization.
Args:
bounds (ee.Geometry|ee.Feature|ee.FeatureCollection): The bounds to sample.
bands (list, optional): The bands to sample. If None, all bands are used.
percent (float, optional): The percent to use when stretching.
sigma (float, optional): The number of standard deviations to use when
stretching.
Returns:
tuple: The minimum and maximum values to clip to.
"""
bands = self._ee_object.bandNames() if bands is None else tuple(bands)
try:
min_val, max_val, std, mean = self._calculate_vis_stats(
bounds=bounds, bands=bands
)
except ValueError:
return (0, 0)

if sigma is not None:
stretch_min = mean - sigma * std
stretch_max = mean + sigma * std
elif percent is not None:
x = (max_val - min_val) * (1 - percent)
stretch_min = min_val + x
stretch_max = max_val - x
else:
stretch_min = min_val
stretch_max = max_val

return (stretch_min, stretch_max)
57 changes: 10 additions & 47 deletions geemap/map_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1322,59 +1322,22 @@ def _value_stretch_changed(self, value):
self._stretch_button.disabled = True
self._value_range_slider.disabled = False

def _update_stretch(self, *args):
def _update_stretch(self, *_):
"""Calculate and set the range slider by applying stretch parameters."""
stretch_params = self._stretch_dropdown.value

min_val, max_val = self._calculate_stretch(**stretch_params)
self._value_range_slider.min = min_val
self._value_range_slider.max = max_val
self._value_range_slider.value = [min_val, max_val]

def _calculate_stretch(self, percent=None, sigma=None):
"""Calculate min and max stretch values for the raster image."""
(s, w), (n, e) = self._host_map.bounds
map_bbox = ee.Geometry.BBox(west=w, south=s, east=e, north=n)
vis_bands = list(set((b.value for b in self._bands_hbox.children)))

stat_reducer = (ee.Reducer.minMax()
.combine(ee.Reducer.mean().unweighted(), sharedInputs=True)
.combine(ee.Reducer.stdDev(), sharedInputs=True))

stats = self._ee_object.select(vis_bands).reduceRegion(
reducer=stat_reducer,
geometry=map_bbox,
bestEffort=True,
maxPixels=10_000,
crs="SR-ORG:6627",
scale=1,
).getInfo()

mins, maxs, stds, means = [
{v for k, v in stats.items() if k.endswith(stat) and v is not None}
for stat in ('_min', '_max', '_stdDev', '_mean')
]
if any(len(vals) == 0 for vals in (mins, maxs, stds, means)):
# No unmasked pixels were sampled
return (0, 0)

min_val = min(mins)
max_val = max(maxs)
std_dev = sum(stds) / len(stds)
mean = sum(means) / len(means)

if sigma is not None:
stretch_min = mean - sigma * std_dev
stretch_max = mean + sigma * std_dev
elif percent is not None:
x = (max_val - min_val) * (1 - percent)
stretch_min = min_val + x
stretch_max = max_val - x
else:
stretch_min = min_val
stretch_max = max_val
vis_bands = set((b.value for b in self._bands_hbox.children))
min_val, max_val = self._ee_layer.calculate_vis_minmax(
bounds=map_bbox,
bands=vis_bands,
**stretch_params
)

return (stretch_min, stretch_max)
self._value_range_slider.min = min_val
self._value_range_slider.max = max_val
self._value_range_slider.value = [min_val, max_val]

def _get_tool_layout(self, grayscale):
return [
Expand Down

0 comments on commit f2b3a6c

Please sign in to comment.