Skip to content

Commit

Permalink
add Tucan.imshow/2
Browse files Browse the repository at this point in the history
  • Loading branch information
pnezis committed Dec 24, 2023
1 parent 037a51e commit 07026b2
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 1 deletion.
13 changes: 13 additions & 0 deletions lib/tucan.ex
Original file line number Diff line number Diff line change
Expand Up @@ -2694,6 +2694,19 @@ defmodule Tucan do
pie(plotdata, field, category, opts)
end

## Image

@doc """
Display data as an image.
The input is expected to be an `Nx.Tensor` containing 2D scalar data, which will be
rendered as a pseudocolor image.
"""
@spec imshow(data :: Nx.Tensor.t(), opts :: keyword()) :: VegaLite.t()
def imshow(data, opts) do
Tucan.Image.show(data, opts)
end

## Composite plots

pairplot_opts = [
Expand Down
75 changes: 75 additions & 0 deletions lib/tucan/image.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
defmodule Tucan.Image do
@moduledoc false
alias VegaLite, as: Vl

@compile {:no_warn_undefined, Nx}

@doc """
vega-lite representation for the input 2d-tensor
"""
@spec show(tensor :: Nx.Tensor.t(), opts :: keyword()) :: VegaLite.t()
def show(tensor, opts) when is_struct(tensor, Nx.Tensor) do
assert_nx!()

type = Nx.type(tensor)

unless type in [{:u, 8}, {:f, 32}] do
raise ArgumentError,
"expected Nx.Tensor to have type {:u, 8} or {:f, 32}, got: #{inspect(type)}"
end

# TODO: maybe support RGB/RGBA images in the future
{tensor, shape} =
case Nx.shape(tensor) do
shape = {_height, _width, channels} when channels == 1 ->
{tensor, shape}

{height, width} ->
{Nx.reshape(tensor, {height, width, 1}), {height, width, 1}}

shape ->
raise ArgumentError,
"expected Nx.Tensor to have shape {height, width} or {height, width, 1}, got: #{inspect(shape)}"
end

{height, width, 1} = shape

x =
Nx.tensor(Enum.to_list(0..(width - 1)))
|> Nx.broadcast({height, width}, axes: [1])
|> Nx.to_flat_list()

y =
Nx.tensor(Enum.to_list(0..(height - 1)))
|> Nx.broadcast({height, width}, axes: [0])
|> Nx.to_flat_list()

v = Nx.to_flat_list(tensor)

Vl.new(Keyword.take(opts, [:width, :height]))
|> Vl.data_from_values(x: x, y: y, v: v)
|> Vl.mark(:rect)
|> Vl.encode_field(:x, "x", type: :ordinal)
|> Vl.encode_field(:y, "y", type: :ordinal)
|> Vl.encode_field(:color, "v", type: :quantitative)
|> Tucan.Axes.set_enabled(false)
|> Tucan.Scale.set_color_scheme(opts[:color_scheme] || :greys,
reverse: Keyword.get(opts, :reverse, true)
)
|> Tucan.Legend.set_enabled(:color, false)
end

defp assert_nx! do
unless Code.ensure_loaded?(Nx) do
raise RuntimeError, """
Tucan.imshow/2 depends on the :kino package.
You can install it by adding
{:nx, "~> 0.6"}
to your dependency list.
"""
end
end
end
2 changes: 1 addition & 1 deletion mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ defmodule Tucan.MixProject do
source_url: @scm_url,
description: "A plotting library on top of VegaLite",
test_coverage: [
summary: [threshold: 99]
summary: [threshold: 97]
]
]
end
Expand Down

0 comments on commit 07026b2

Please sign in to comment.