Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug when all classes are hidden #55

Merged
merged 3 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 24 additions & 21 deletions sankee/datasets.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

from warnings import warn

import ee
import pandas as pd
import plotly.graph_objects as go

from sankee import themes
from sankee.plotting import sankify
from sankee.plotting import SankeyPlot, sankify


class Dataset:
Expand Down Expand Up @@ -76,14 +77,14 @@ def get_year(self, year: int) -> ee.Image:
)

img = self.collection.filterDate(str(year), str(year + 1)).first()
img = self.set_visualization_properties(img)
img = self._set_visualization_properties(img)

if self.nodata is not None:
img = img.updateMask(img.neq(self.nodata))

return img.select(self.band)

def set_visualization_properties(self, image: ee.Image) -> ee.Image:
def _set_visualization_properties(self, image: ee.Image) -> ee.Image:
"""Set the properties used by Earth Engine to automatically assign a palette to an image
from this dataset."""
return image.set(
Expand All @@ -93,7 +94,7 @@ def set_visualization_properties(self, image: ee.Image) -> ee.Image:
[c.replace("#", "") for c in self.palette.values()],
)

def list_years(self) -> ee.List:
def _list_years(self) -> ee.List:
"""Get an ee.List of all years in the collection."""
return (
self.collection.aggregate_array("system:time_start")
Expand All @@ -113,7 +114,7 @@ def sankify(
exclude: None = None,
label_type: str = "class",
theme: str | themes.Theme = themes.DEFAULT,
) -> go.Figure:
) -> SankeyPlot:
"""
Generate an interactive Sankey plot showing land cover change over time from a series of
years in the dataset.
Expand All @@ -124,11 +125,6 @@ def sankify(
The years to include in the plot. Select at least two unique years.
region : ee.Geometry
A region to generate samples within. The region must overlap all images.
exclude : list[int], default None
An optional list of pixel values to exclude from the plot. Excluded values must be raw
pixel values rather than class labels. This can be helpful if the region is dominated by
one or more unchanging classes and the goal is to visualize changes in smaller classes.
No-data classes are always excluded automatically.
max_classes : int, default None
If a value is provided, small classes will be removed until max_classes remain. Class
size is calculated based on total times sampled in the time series.
Expand All @@ -143,19 +139,26 @@ def sankify(
projection.
seed : int, default 0
The seed value used to generate repeatable results during random sampling.
exclude : None
Unused parameter that will be removed in a future release.
label_type : str, default "class"
The type of label to display for each link, one of "class", "percent", or "count".
Selecting "class" will use the class label, "percent" will use the proportion of
sampled pixels in each class, and "count" will use the number of sampled pixels in each
class.
theme : str or Theme
The theme to apply to the Sankey diagram. Can be the name of a built-in theme
(e.g. "d3") or a custom `sankee.Theme` object.

Returns
-------
plotly.graph_objs._figure.Figure
An interactive Sankey plot.
SankeyPlot
An interactive Sankey plot widget.
"""
if exclude is not None:
warn(
"The `exclude` parameter is unused and will be removed in a future release.",
DeprecationWarning,
stacklevel=2,
)
if len(years) < 2:
raise ValueError("Select at least two years.")
if len(set(years)) != len(years):
Expand Down Expand Up @@ -187,7 +190,7 @@ def sankify(
)


class LCMS_Dataset(Dataset):
class _LCMS_Dataset(Dataset):
def get_year(self, year: int) -> ee.Image:
"""Get one year's image from the dataset. LCMS splits up each year into two images: CONUS
and SEAK. This merges those into a single image."""
Expand All @@ -203,7 +206,7 @@ def get_year(self, year: int) -> ee.Image:
return merged


class CCAP_Dataset(Dataset):
class _CCAP_Dataset(Dataset):
def get_year(self, year: int) -> ee.Image:
"""Get one year's image from the dataset. C-CAP splits up each year into multiple images,
so merge those and set the class value and palette metadata to allow automatic
Expand All @@ -224,12 +227,12 @@ def get_year(self, year: int) -> ee.Image:
.setDefaultProjection("EPSG:5070")
)

img = self.set_visualization_properties(img)
img = self._set_visualization_properties(img)

return img


LCMS_LU = LCMS_Dataset(
LCMS_LU = _LCMS_Dataset(
name="LCMS LU - Land Change Monitoring System Land Use",
id="USFS/GTAC/LCMS/v2022-8",
band="Land_Use",
Expand All @@ -256,7 +259,7 @@ def get_year(self, year: int) -> ee.Image:
)

# https://developers.google.com/earth-engine/datasets/catalog/USFS_GTAC_LCMS_v2022-8
LCMS_LC = LCMS_Dataset(
LCMS_LC = _LCMS_Dataset(
name="LCMS LC - Land Change Monitoring System Land Cover",
id="USFS/GTAC/LCMS/v2022-8",
band="Land_Cover",
Expand Down Expand Up @@ -538,7 +541,7 @@ def get_year(self, year: int) -> ee.Image:
)

# https://samapriya.github.io/awesome-gee-community-datasets/projects/ccap_mlc/
CCAP_LC30 = CCAP_Dataset(
CCAP_LC30 = _CCAP_Dataset(
name="C-CAP - NOAA Coastal Change Analysis Program 30m",
id="projects/sat-io/open-datasets/NOAA/ccap_30m",
band="b1",
Expand Down
83 changes: 44 additions & 39 deletions sankee/plotting.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from collections import namedtuple
from typing import Literal

import ee
import ipywidgets as widgets
Expand Down Expand Up @@ -36,9 +37,9 @@ def sankify(
title: None | str = None,
scale: None | int = None,
seed: int = 0,
label_type: None | str = "class",
label_type: None | Literal["class", "percent", "count"] = "class",
theme: str | themes.Theme = "default",
) -> go.Figure:
) -> SankeyPlot:
"""
Generate an interactive Sankey plot showing land cover change over time from a series of images.

Expand Down Expand Up @@ -77,18 +78,18 @@ def sankify(
seed : int, default 0
The seed value used to generate repeatable results during random sampling.
label_type : str, default "class"
The type of label to display for each link, one of "class", "percent", "count", or False.
The type of label to display for each link, one of "class", "percent", "count", or None.
Selecting "class" will use the class label, "percent" will use the proportion of sampled
pixels in each class, and "count" will use the number of sampled pixels in each class.
False will disable link labels.
None will disable link labels.
theme : str or Theme
The theme to apply to the Sankey diagram. Can be the name of a built-in theme (e.g. "d3") or
a custom `sankee.Theme` object.

Returns
-------
plotly.graph_objs._figure.Figure
An interactive Sankey plot.
SankeyPlot
An interactive Sankey plot widget.
"""
if region is None:
region = image_list[0].geometry()
Expand Down Expand Up @@ -126,12 +127,13 @@ def sankify(
class SankeyPlot(widgets.DOMWidget):
def __init__(
self,
*,
data: pd.DataFrame,
labels: dict[int, str],
palette: dict[int, str],
title: str,
samples: ee.FeatureCollection,
label_type: str,
label_type: None | Literal["class", "percent", "count"],
theme: str | themes.Theme,
):
self.data = data
Expand All @@ -140,15 +142,15 @@ def __init__(
self.title = title
self.samples = samples
self.label_type = label_type
self.theme = theme
self.theme = theme if isinstance(theme, themes.Theme) else themes.load_theme(theme)

self.hide = []
# Initialized by `self.generate_plot`
self.df = None
self.plot = self.generate_plot()
self.gui = self.generate_gui()
self.plot = self._generate_figurewidget()
self.gui = self._generate_gui()

def get_sorted_classes(self) -> pd.Series:
def _get_sorted_classes(self) -> pd.Series:
"""Return all unique class values, sorted by the total number of observations."""
start_count = (
self.df.loc[:, ["source", "total"]]
Expand All @@ -168,11 +170,11 @@ def get_sorted_classes(self) -> pd.Series:

return total_count.sort_values(by="count", ascending=False)["class"].reset_index(drop=True)

def get_active_classes(self) -> pd.Series:
def _get_active_classes(self) -> pd.Series:
"""Return all unique active, visible class values after filtering."""
return self.df[["source", "target"]].melt().value.unique()

def generate_plot_parameters(self) -> SankeyParameters:
def _generate_plot_parameters(self) -> SankeyParameters:
"""Generate Sankey plot parameters from a formatted, cleaned dataframe"""
df = self.df.copy()

Expand Down Expand Up @@ -226,7 +228,7 @@ def generate_plot_parameters(self) -> SankeyParameters:
all_classes["label"] = ""
else:
raise ValueError(
"Invalid label_type. Choose from 'class', 'percent', 'count', or False."
"Invalid label_type. Choose from 'class', 'percent', 'count', or None."
)

return SankeyParameters(
Expand All @@ -240,7 +242,7 @@ def generate_plot_parameters(self) -> SankeyParameters:
value=df.changed,
)

def generate_dataframe(self) -> pd.DataFrame:
def _generate_dataframe(self) -> pd.DataFrame:
"""Convert raw sampling data to a formatted dataframe"""
data = self.data.copy()

Expand Down Expand Up @@ -299,14 +301,15 @@ def _model_id(self):
return self.gui._model_id

def update_layout(self, *args, **kwargs):
"""Pass layout changes to the plot. This is primarily kept for compatibility with geemap."""
"""Pass layout changes to the plot."""
# This is primarily kept for compatibility with geemap
self.plot.update_layout(*args, **kwargs)

def generate_gui(self):
def _generate_gui(self):
BUTTON_HEIGHT = "24px"
BUTTON_WIDTH = "24px"

unique_classes = self.get_sorted_classes()
unique_classes = self._get_sorted_classes()

def toggle_button(button):
button.toggle()
Expand All @@ -322,13 +325,13 @@ def toggle_button(button):
update_plot()

def update_plot():
"""Swap new data into the plot"""
new_plot = self.generate_plot()
self.plot.data[0].link = new_plot.data[0].link
self.plot.data[0].node = new_plot.data[0].node
"""Swap new data into the plot."""
new_sankey = self._generate_sankey()
self.plot.data[0].link = new_sankey.link
self.plot.data[0].node = new_sankey.node

buttons = []
active_classes = self.get_active_classes()
active_classes = self._get_active_classes()
for i in unique_classes:
label = self.labels[i]
on_color = self.palette[i]
Expand Down Expand Up @@ -373,18 +376,20 @@ def reset_plot(_):

return gui

def generate_plot(self) -> go.Figure:
self.df = self.generate_dataframe()
params = self.generate_plot_parameters()
def _generate_sankey(self) -> go.Figure:
"""Generate the Sankey plot based on the currently visible classes."""
self.df = self._generate_dataframe()
# Explicitly return an empty Sankey plot if all classes are hidden to avoid widget update
# errors.
if len(self.df) == 0:
return go.Sankey()

theme = (
self.theme if isinstance(self.theme, themes.Theme) else themes.load_theme(self.theme)
)
params = self._generate_plot_parameters()

node_kwargs = dict(
customdata=params.node_labels,
hovertemplate="<b>%{customdata}</b><extra></extra>",
label=[f"<span style='{theme.label_style}'>{s}</span>" for s in params.label],
label=[f"<span style='{self.theme.label_style}'>{s}</span>" for s in params.label],
color=params.node_palette,
)
link_kwargs = dict(
Expand All @@ -396,18 +401,18 @@ def generate_plot(self) -> go.Figure:
hovertemplate="%{customdata} <extra></extra>",
)

fig = go.FigureWidget(
data=[
go.Sankey(
arrangement="snap",
node={**node_kwargs, **theme.node_kwargs},
link={**link_kwargs, **theme.link_kwargs},
)
]
return go.Sankey(
arrangement="snap",
node={**node_kwargs, **self.theme.node_kwargs},
link={**link_kwargs, **self.theme.link_kwargs},
)

def _generate_figurewidget(self) -> go.FigureWidget:
"""Generate the FigureWidget that wraps the Sankey plot."""
fig = go.FigureWidget(data=[self._generate_sankey()])

fig.update_layout(
title_text=f"<span style='{theme.title_style}'>{self.title}</span>"
title_text=f"<span style='{self.theme.title_style}'>{self.title}</span>"
if self.title
else None,
font_size=16,
Expand Down
7 changes: 4 additions & 3 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ def test_get_year_CORINE():

@pytest.mark.parametrize("dataset", sankee.datasets.datasets, ids=lambda d: d.name)
def test_years(dataset):
assert dataset.years == tuple(dataset.list_years().getInfo())
"""Check that the hard-coded dataset years match the Earth Engine catalog years."""
assert dataset.years == tuple(dataset._list_years().getInfo())


def test_get_unsupported_year():
Expand Down Expand Up @@ -117,8 +118,8 @@ def test_sankify():
title="My plot!",
)

params1 = sankey1.generate_plot_parameters()
params2 = sankey2.generate_plot_parameters()
params1 = sankey1._generate_plot_parameters()
params2 = sankey2._generate_plot_parameters()

for p1, p2 in zip(params1, params2):
assert_series_equal(p1, p2)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ def sankey():

def test_get_sorted_classes(sankey):
"""Test that classes are correctly sorted."""
assert_series_equal(sankey.get_sorted_classes(), pd.Series([1, 2, 4, 3]), check_names=False)
assert_series_equal(sankey._get_sorted_classes(), pd.Series([1, 2, 4, 3]), check_names=False)


def test_plot_parameters(sankey):
"""Test that plot parameters are generated correctly."""
params = sankey.generate_plot_parameters()
params = sankey._generate_plot_parameters()
node_labels = ["start", "start", "start", "end", "end", "end", "end"]
label = [
"Agriculture",
Expand Down
Loading