Skip to content

Commit

Permalink
support passing Nx.Tensor as series in Tucan.new
Browse files Browse the repository at this point in the history
  • Loading branch information
pnezis committed Dec 23, 2023
1 parent 40c92d4 commit 8f15394
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 1 deletion.
142 changes: 141 additions & 1 deletion lib/tucan.ex
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,18 @@ defmodule Tucan do
Tucan.scatter(:iris, "petal_width", "petal_length")
```
> #### `Nx` support {: .neutral}
>
> If `:nx` is installed as a dependency you can additionally pass directly the data
> columns as tensors. For example:
>
> ```tucan
> x = Nx.linspace(-20, 20, n: 200)
> y = Nx.pow(x, 2)
>
> Tucan.lineplot([x: x, y: y], "x", "y", width: 400)
> ```
You can apply semantic grouping by a third variable by modifying the color, the
shape or the size of the points:
Expand Down Expand Up @@ -160,7 +172,101 @@ defmodule Tucan do
* if it is an atom then it is considered a `Tucan.Dataset` and it is translated to
the dataset's url. If the dataset name is invalid an exception is raised.
* in any other case it is considered a set of data values and the values are set
as data to a newly created `VegaLite` struct.
as data to a newly created `VegaLite` struct. Any tabular data is accepted, as
long as it adheres to the `Table.Reader` protocol.
## Examples
Passing a URL to some dataset
```elixir
Tucan.new("https://vega.github.io/editor/data/penguins.json")
|> ...
Tucan.new("https://vega.github.io/editor/data/stocks.csv", format: :csv)
|> ...
```
Using a pre-defined `Tucan.Dataset`
```elixir
Tucan.new(:penguins)
|> ...
Tucan.new(:Lris)
|> ...
```
Passing directly tabular data
```elixir
data = [
%{"category" => "A", "score" => 28},
%{"category" => "B", "score" => 55}
]
Tucan.new(data)
|> ...
```
You can also pass individual series:
```elixir
xs = 1..100
ys = 1..100
Tucan.new(x: xs, y: ys)
|> ...
```
Any data that adheres to the `Table.Reader` protocol is accepted, for example
you could pass an [`Explorer.DataFrame`](https://hexdocs.pm/explorer/Explorer.DataFrame.html)
```elixir
mountains = Explorer.DataFrame.new(
name: ["Everest", "K2", "Aconcagua"],
elevation: [8848, 8611, 6962]
)
Tucan.new(mountains)
|> ...
```
Additionally you can pass `Nx.Tensor`s as series. These will be implicitly
transformed to lists.
```elixir
xs = Nx.linspace(0, 10, n: 100)
ys = Nx.sin(xs)
Tucan.new([x: xs, y: ys])
|> ...
```
> #### Valid `Nx.Tensor` shapes {: .info}
>
> 1-dimensional tensors are expected when you pass `Nx` tensors as series.
> Additionally for convenience 2-dimensional tensors where one of the two
> dimensions are `1` are also supported.
>
> For example the following are equivalent
>
> ```elixir
> x = Nx.linspace(0, 10, n: 10)
> y = Nx.pow(x, 2)
>
> plot1 = Tucan.new(x: x, y: y)
>
> x = Nx.reshape(x, {10, 1})
> y = Nx.reshape(y, {1, 10})
>
> plot2 = Tucan.new(x: x, y: y)
>
> assert plot1 == plot2
> ```
>
> For all other tensor shapes an `ArgumentError` will be raised.
"""
@doc section: :utilities
@spec new(plotdata :: plotdata(), opts :: keyword()) :: VegaLite.t()
Expand All @@ -184,11 +290,45 @@ defmodule Tucan do
defp to_vega_plot(data, opts) do
{data_opts, spec_opts} = Keyword.split(opts, [:only])

data = maybe_transform_data(data)

spec_opts
|> new_tucan_plot()
|> Vl.data_from_values(data, data_opts)
end

defp maybe_transform_data(data) do
case Keyword.keyword?(data) do
false ->
data

true ->
for {key, column} <- data do
{key, maybe_nx_to_list(column, key)}
end
end
end

@compile {:no_warn_undefined, Nx}

defp maybe_nx_to_list(column, name) when is_struct(column, Nx.Tensor) do
shape = Nx.shape(column)

unless valid_shape?(shape) do
raise ArgumentError,
"invalid shape for #{name} tensor, expected a 1-d tensor, got a #{inspect(shape)} tensor"
end

Nx.to_flat_list(column)
end

defp maybe_nx_to_list(column, _name), do: column

defp valid_shape?({_x}), do: true
defp valid_shape?({1, _x}), do: true
defp valid_shape?({_x, 1}), do: true
defp valid_shape?(_shape), do: false

defp new_tucan_plot(opts) do
{tucan_opts, opts} = Keyword.pop(opts, :tucan)

Expand Down
1 change: 1 addition & 0 deletions mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ defmodule Tucan.MixProject do
{:nimble_options, "~> 1.0"},
{:vega_lite, "~> 0.1.8"},
{:jason, "~> 1.4"},
{:nx, "~> 0.6", optional: true},
{:ex_doc, "~> 0.30", only: :dev, runtime: false},
{:fancy_fences, "~> 0.3.0", only: :dev, runtime: false}
] ++ dev_deps()
Expand Down
3 changes: 3 additions & 0 deletions mix.lock
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
%{
"bunt": {:hex, :bunt, "0.2.1", "e2d4792f7bc0ced7583ab54922808919518d0e57ee162901a16a1b6664ef3b14", [:mix], [], "hexpm", "a330bfb4245239787b15005e66ae6845c9cd524a288f0d141c148b02603777a5"},
"complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"},
"credo": {:hex, :credo, "1.7.1", "6e26bbcc9e22eefbff7e43188e69924e78818e2fe6282487d0703652bc20fd62", [:mix], [{:bunt, "~> 0.2.1", [hex: :bunt, repo: "hexpm", optional: false]}, {:file_system, "~> 0.2.8", [hex: :file_system, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}], "hexpm", "e9871c6095a4c0381c89b6aa98bc6260a8ba6addccf7f6a53da8849c748a58a2"},
"decimal": {:hex, :decimal, "2.1.1", "5611dca5d4b2c3dd497dec8f68751f1f1a54755e8ed2a966c2633cf885973ad6", [:mix], [], "hexpm", "53cfe5f497ed0e7771ae1a475575603d77425099ba5faef9394932b35020ffcc"},
"dialyxir": {:hex, :dialyxir, "1.4.2", "764a6e8e7a354f0ba95d58418178d486065ead1f69ad89782817c296d0d746a5", [:mix], [{:erlex, ">= 0.2.6", [hex: :erlex, repo: "hexpm", optional: false]}], "hexpm", "516603d8067b2fd585319e4b13d3674ad4f314a5902ba8130cd97dc902ce6bbd"},
Expand All @@ -15,6 +16,8 @@
"makeup_erlang": {:hex, :makeup_erlang, "0.1.2", "ad87296a092a46e03b7e9b0be7631ddcf64c790fa68a9ef5323b6cbb36affc72", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "f3f5a1ca93ce6e092d92b6d9c049bcda58a3b617a8d888f8e7231c85630e8108"},
"nimble_options": {:hex, :nimble_options, "1.0.2", "92098a74df0072ff37d0c12ace58574d26880e522c22801437151a159392270e", [:mix], [], "hexpm", "fd12a8db2021036ce12a309f26f564ec367373265b53e25403f0ee697380f1b8"},
"nimble_parsec": {:hex, :nimble_parsec, "1.4.0", "51f9b613ea62cfa97b25ccc2c1b4216e81df970acd8e16e8d1bdc58fef21370d", [:mix], [], "hexpm", "9c565862810fb383e9838c1dd2d7d2c437b3d13b267414ba6af33e50d2d1cf28"},
"nx": {:hex, :nx, "0.6.4", "948d9f42f81e63fc901d243ac0a985c8bb87358be62e27826cfd67f58bc640af", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "bb9c2e2e3545b5eb4739d69046a988daaa212d127dba7d97801c291616aff6d6"},
"table": {:hex, :table, "0.1.2", "87ad1125f5b70c5dea0307aa633194083eb5182ec537efc94e96af08937e14a8", [:mix], [], "hexpm", "7e99bc7efef806315c7e65640724bf165c3061cdc5d854060f74468367065029"},
"telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"},
"vega_lite": {:hex, :vega_lite, "0.1.8", "7f6119126ecaf4bc2c1854084370d7091424f5cce4795fbac044eee9963f0752", [:mix], [{:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "6c8a9271f850612dd8a90de8d1ebd433590ed07ffef76fc2397c240dc04d3fdc"},
}
26 changes: 26 additions & 0 deletions test/tucan_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,32 @@ defmodule TucanTest do
vl = Tucan.new(:iris, width: 100, height: 100, foo: 2, tucan: [plot: true])
assert get_in(vl.spec, ["__tucan__"]) == %{"plot" => true}
end

test "with nx tensors" do
x = Nx.linspace(0, 4, n: 5)
y = Nx.add(x, 1)

expected =
Vl.new(width: 100, height: 100)
|> Vl.data_from_values(x: [0, 1, 2, 3, 4], y: [1, 2, 3, 4, 5])

assert Tucan.new([x: x, y: y], width: 100, height: 100) == expected

x = Nx.reshape(x, {5, 1})
y = Nx.reshape(y, {1, 5})

assert Tucan.new([x: x, y: y], width: 100, height: 100) == expected

assert Tucan.new([x: x, y: 1..5], width: 100, height: 100) == expected
end

test "raises with invalid nx shape" do
x = Nx.linspace(0, 10, n: 10) |> Nx.reshape({2, 5})

assert_raise ArgumentError,
"invalid shape for x tensor, expected a 1-d tensor, got a {2, 5} tensor",
fn -> Tucan.new(x: x) end
end
end

describe "histogram/3" do
Expand Down

0 comments on commit 8f15394

Please sign in to comment.