Skip to content
Permalink

Comparing changes

This is a direct comparison between two commits made in this repository or its related repositories. View the default comparison for this range or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also . Learn more about diff comparisons here.
base repository: astropy/astrowidgets
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: 41d33359c8b538734c03fa3219e5f32c4d2af70d
Choose a base ref
..
head repository: astropy/astrowidgets
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: 1e71af4b5a8967f7024a6916a4d6001362a1a18b
Choose a head ref
Showing with 266 additions and 82 deletions.
  1. +245 −72 astrowidgets/bqplot.py
  2. +21 −10 astrowidgets/tests/test_bqplot_api.py
317 changes: 245 additions & 72 deletions astrowidgets/bqplot.py
Original file line number Diff line number Diff line change
@@ -3,10 +3,11 @@
from astropy.coordinates import SkyCoord
from astropy.io import fits
from astropy.nddata import CCDData
from astropy.table import Table, vstack
from astropy import units as u
import astropy.visualization as apviz

from bqplot import Figure, LinearScale, Axis, ColorScale, PanZoom
from bqplot import Figure, LinearScale, Axis, ColorScale, PanZoom, ScatterGL
from bqplot_image_gl import ImageGL
from bqplot_image_gl.interacts import (MouseInteraction,
keyboard_events, mouse_events)
@@ -61,6 +62,8 @@ def __init__(self, image_data=None,
'image': ColorScale(max=1.114, min=2902,
scheme='Greys')}

self._scatter_marks = {}

self._figure = Figure(scales=self._scales, axes=[axis_x, axis_y],
fig_margin=dict(top=0, left=0,
right=0, bottom=0),
@@ -79,6 +82,10 @@ def __init__(self, image_data=None,

self._figure.interaction = interaction

# Keep track of this separately so that it is easy to change
# its state.
self._panzoom = panzoom

if image_data:
self.set_data(image_data, reset_view=True)

@@ -180,6 +187,67 @@ def save_png(self, filename):
def save_svg(self, filename):
self._figure.save_svg(filename)

def set_pan(self, on_or_off):
self._panzoom.allow_pan = on_or_off

def set_scroll_zoom(self, on_or_off):
self._panzoom.allow_zoom = on_or_off

def set_size(self, size, direction):
scale_to_set = self._scales[direction]
cen = {}
cen['x'], cen['y'] = self.center
scale_to_set.min = cen[direction] - size/2
scale_to_set.max = cen[direction] + size/2

reset_scale = 'x' if direction == 'y' else 'y'

self._set_scale_aspect_ratio_to_match_viewer(reset_scale)

def set_zoom_level(self, zoom_level):
"""
Set zoom level of viewer. A zoom level of 1 means 1 pixel
in the image is 1 pixel in the viewer, i.e. the scale width
in the horizontal direction matches the width in pixels
of the figure.
"""

# The width is reset here but the height could be set instead
# and the result would be the same.
figure_width = float(self._figure.layout.width[:-2])
new_width = figure_width / zoom_level
self.set_size(new_width, 'x')
self._set_scale_aspect_ratio_to_match_viewer('y')

def plot_named_markers(self, x, y, mark_id, color='yellow',
size=100, style='circle'):
scale_dict = dict(x=self._scales['x'], y=self._scales['y'])
sc = ScatterGL(scales=scale_dict,
x=x, y=y,
colors=[color],
default_size=100,
marker=style,
fill=False)

self._scatter_marks[mark_id] = sc
self._update_marks()

def remove_named_markers(self, mark_id):
try:
del self._scatter_marks[mark_id]
except KeyError:
raise ValueError('Markers {mark_id} are not present.')

self._update_marks()

def remove_markers(self):
self._scatter_marks = {}
self._update_marks()

def _update_marks(self):
marks = [self._image] + [mark for mark in self._scatter_marks.values()]
self._figure.marks = marks


def bqcolors(colormap, reverse=False):
# bqplot-image-gl has 256 levels
@@ -201,68 +269,66 @@ def bqcolors(colormap, reverse=False):


class MarkerTableManager:
"""
Table for keeping track of positions and names of sets of
logically-related markers.
"""
def __init__(self):
pass
# These column names are for internal use.
self._xcol = 'x'
self._ycol = 'y'
self._names = 'name'
self._marktags = set()
# Let's have a default name for the tag too:
self.default_mark_tag_name = 'default-marker-name'
self._interactive_marker_set_name_default = 'interactive-markers'
self._interactive_marker_set_name = self._interactive_marker_set_name_default
self._init_table()

def _init_table(self):
self._table = Table(names=(self._xcol, self._ycol, self._names),
dtype=('int32', 'int32', 'str'))

def add_markers(self, table, x_colname='x', y_colname='y',
skycoord_colname='coord', use_skycoord=False,
marker_name=None):
@property
def xcol(self):
return self._xcol

@property
def ycol(self):
return self._ycol

@property
def names(self):
return self._names

# For now we always convert marker locations to pixels; see
# comment below.
coord_type = 'data'
@property
def marker_names(self):
return sorted(set(self._table[self.names]))

def add_markers(self, x_mark, y_mark,
marker_name=None):

if marker_name is None:
marker_name = self._default_mark_tag_name
marker_name = self.default_mark_tag_name

self.validate_marker_name(marker_name)
self._marktags.add(marker_name)
for x, y in zip(x_mark, y_mark):
self._table.add_row([x, y, marker_name])

# Extract coordinates from table.
# They are always arrays, not scalar.
if use_skycoord:
image = self._viewer.get_image()
if image is None:
raise ValueError('Cannot get image from viewer')
if image.wcs.wcs is None:
raise ValueError(
'Image has no valid WCS, '
'try again with use_skycoord=False')
coord_val = table[skycoord_colname]
# TODO: Maybe switch back to letting Ginga handle conversion
# to pixel coordinates.
# Convert to pixels here (instead of in Ginga) because conversion
# in Ginga was reportedly very slow.
coord_x, coord_y = image.wcs.wcs.all_world2pix(
coord_val.ra.deg, coord_val.dec.deg, 0)
# In the event a *single* marker has been added, coord_x and coord_y
# will be scalars. Make them arrays always.
if np.ndim(coord_x) == 0:
coord_x = np.array([coord_x])
coord_y = np.array([coord_y])
else: # Use X,Y
coord_x = table[x_colname].data
coord_y = table[y_colname].data
# Convert data coordinates from 1-indexed to 0-indexed
if self._pixel_offset != 0:
# Don't use the in-place operator -= here that modifies
# the input table.
coord_x = coord_x - self._pixel_offset
coord_y = coord_y - self._pixel_offset

# Prepare canvas and retain existing marks
try:
c_mark = self._viewer.canvas.get_object_by_tag(marker_name)
except Exception:
objs = []
else:
objs = c_mark.objects
self._viewer.canvas.delete_object_by_tag(marker_name)
def get_markers_by_name(self, marker_name):
matches = self._table[self._names] == marker_name
return self._table[matches]

def get_all_markers(self):
return self._table.copy()

def remove_markers_by_name(self, marker_name):
matches = self._table[self._names] == marker_name
# Only keep the things that don't match
self._table = self._table[~matches]

# TODO: Test to see if we can mix WCS and data on the same canvas
objs += [self._marker(x=x, y=y, coord=coord_type)
for x, y in zip(coord_x, coord_y)]
self._viewer.canvas.add(self.dc.CompoundObject(*objs), tag=marker_name)
def remove_all_markers(self):
self._init_table()


"""
@@ -298,6 +364,11 @@ def __init__(self, *args, image_width=500, image_height=500):
self._interval = None
self._stretch = None
self._colormap = 'Grays'
self._marker_table = MarkerTableManager()
self._data = None
self._wcs = None
self._is_marking = False
self.marker = {'color': 'red', 'radius': 20, 'type': 'square'}

def _interval_and_stretch(self):
"""
@@ -382,6 +453,35 @@ def _observe_cuts(self, change):
else:
self._interval = cuts

@trait.observe('zoom_level')
def _update_zoom_level(self, change):
zl = change['new']

self._astro_im.set_zoom_level(zl)

@trait.validate('click_drag')
def _validate_click_drag(self, proposal):
cd = proposal['value']
if cd and self._is_marking:
raise ValueError('Cannot set click_drag while doing interactive '
'marking. Call the stop_marking() method to '
'stop marking and then set click_drag.')
return cd

@trait.observe('click_drag')
def _update_viewer_pan(self, change):
# Turn of click-to-center
if change['new']:
self.click_center = False

self._astro_im.set_pan(change['new'])

@trait.observe('scroll_pan')
def _update_viewer_zoom_scroll(self, change):
raise NotImplementedError('😭 sorry, cannot do that yet')
self._astro_im.set_scroll_zoom(change['new'])


# The methods, grouped loosely by purpose

# Methods for loading data
@@ -403,6 +503,7 @@ def load_fits(self, file_name_or_HDU):

self._ccd = ccd
self._data = ccd.data
self._wcs = ccd.wcs
self._send_data()

def load_array(self, array):
@@ -441,25 +542,98 @@ def colormap_options(self):
# def stop_marking(self):
# raise NotImplementedError

# @abstractmethod
# def add_markers(self):
# raise NotImplementedError
def add_markers(self, table, x_colname='x', y_colname='y',
skycoord_colname='coord', use_skycoord=False,
marker_name=None):

# @abstractmethod
# def get_markers(self):
# raise NotImplementedError
if use_skycoord:
if self._wcs is None:
raise ValueError('The WCS for the image must be set to use '
'world coordinates for markers.')

# @abstractmethod
# def remove_markers(self):
# raise NotImplementedError
x, y = self._wcs.world_to_pixel(table[skycoord_colname])
else:
x = table[x_colname]
y = table[y_colname]

# @abstractmethod
# def get_all_markers(self):
# raise NotImplementedError
# Update the table of marker names and positions
self._marker_table.add_markers(x, y, marker_name=marker_name)

# @abstractmethod
# def get_markers_by_name(self, marker_name=None):
# raise NotImplementedError
# Update the figure itself, which expects all markers of
# the same name to be plotted at once.
marks = self.get_markers_by_name(marker_name)

self._astro_im.plot_named_markers(marks['x'], marks['y'],
marker_name,
color=self.marker['color'],
size=self.marker['radius']**2,
style=self.marker['type'])

def remove_markers_by_name(self, marker_name):
# Remove from our tracking table
self._marker_table.remove_markers_by_name(marker_name)

# Remove from the visible canvas
self._astro_im.remove_named_markers(marker_name)

def remove_all_markers(self):
self._marker_table.remove_all_markers()
self._astro_im.remove_markers()

def _prepare_return_marker_table(self, marks, x_colname='x', y_colname='y',
skycoord_colname='coord'):
if len(marks) == 0:
return None

if (self._data is None) or (self._wcs is None):
# Do not include SkyCoord column
include_skycoord = False
else:
include_skycoord = True
radec_col = []

if include_skycoord:
coords = self._wcs.pixel_to_world(marks[self._marker_table.xcol],
marks[self._marker_table.ycol])
marks[skycoord_colname] = coords

# This might be a null op but should be harmless in that case
marks.rename_column(self._marker_table.xcol, x_colname)
marks.rename_column(self._marker_table.ycol, y_colname)

return marks

def get_markers_by_name(self, marker_name=None, x_colname='x', y_colname='y',
skycoord_colname='coord'):

# We should always allow the default name. The case
# where that table is empty will be handled in a moment.
if (marker_name not in self._marker_table.marker_names
and marker_name != self.marker_table.default_mark_tag_name):
raise ValueError(f"No markers named '{marker_name}' found.")

marks = self._marker_table.get_markers_by_name(marker_name=marker_name)

if len(marks) == 0:
# No markers in this table. Issue a warning and continue.
# Test wants this outside of logger, so...
warnings.warn(f"Marker set named '{marker_name}' is empty", UserWarning)
return None

marks = self._prepare_return_marker_table(marks,
x_colname=x_colname,
y_colname=y_colname,
skycoord_colname=skycoord_colname)
return marks

def get_all_markers(self, x_colname='x', y_colname='y',
skycoord_colname='coord'):
marks = self._marker_table.get_all_markers()
marks = self._prepare_return_marker_table(marks,
x_colname=x_colname,
y_colname=y_colname,
skycoord_colname=skycoord_colname)
return marks

# Methods that modify the view
def center_on(self, point):
@@ -477,6 +651,5 @@ def center_on(self, point):
# def offset_to(self):
# raise NotImplementedError

# @abstractmethod
# def zoom(self):
# raise NotImplementedError
def zoom(self, value):
self.zoom_level = self.zoom_level * value
Loading