From 32f8bc9c43cb36cb08fd4c4e0eb06da3b548eeea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Thu, 6 Mar 2025 10:01:13 +0100 Subject: [PATCH] Do not discard nil on protocol concat, closes #14311 --- lib/elixir/lib/module.ex | 6 ++- lib/elixir/lib/protocol.ex | 47 +++++++++++++++--------- lib/elixir/test/elixir/protocol_test.exs | 7 +++- 3 files changed, 39 insertions(+), 21 deletions(-) diff --git a/lib/elixir/lib/module.ex b/lib/elixir/lib/module.ex index 9033341997..1c66685ad1 100644 --- a/lib/elixir/lib/module.ex +++ b/lib/elixir/lib/module.ex @@ -950,7 +950,8 @@ defmodule Module do @doc """ Concatenates two aliases and returns a new alias. - It handles binaries and atoms. + It handles binaries and atoms. If one of the aliases + is nil, it is discarded. ## Examples @@ -960,6 +961,9 @@ defmodule Module do iex> Module.concat(Foo, "Bar") Foo.Bar + iex> Module.concat(Foo, nil) + Foo + """ @spec concat(binary | atom, binary | atom) :: atom def concat(left, right) diff --git a/lib/elixir/lib/protocol.ex b/lib/elixir/lib/protocol.ex index 1ce76f152c..d8dc803189 100644 --- a/lib/elixir/lib/protocol.ex +++ b/lib/elixir/lib/protocol.ex @@ -394,7 +394,7 @@ defmodule Protocol do end defp assert_impl!(protocol, base, extra) do - impl = Module.concat(protocol, base) + impl = __concat__(protocol, base) try do Code.ensure_compiled!(impl) @@ -684,14 +684,14 @@ defmodule Protocol do types |> List.delete(Any) |> Enum.map(fn impl -> - {[Module.Types.Of.impl(impl)], Descr.atom([Module.concat(protocol, impl)])} + {[Module.Types.Of.impl(impl)], Descr.atom([__concat__(protocol, impl)])} end) {domain, impl_for, impl_for!} = case clauses do [] -> if Any in types do - clauses = [{[Descr.term()], Descr.atom([Module.concat(protocol, Any)])}] + clauses = [{[Descr.term()], Descr.atom([__concat__(protocol, Any)])}] {Descr.term(), clauses, clauses} else {Descr.none(), [{[Descr.term()], Descr.atom([nil])}], @@ -707,7 +707,9 @@ defmodule Protocol do not_domain = Descr.negation(domain) if Any in types do - clauses = clauses ++ [{[not_domain], Descr.atom([Module.concat(protocol, Any)])}] + clauses = + clauses ++ [{[not_domain], Descr.atom([__concat__(protocol, Any)])}] + {Descr.term(), clauses, clauses} else {domain, clauses ++ [{[not_domain], Descr.atom([nil])}], clauses} @@ -746,7 +748,7 @@ defmodule Protocol do end defp change_impl_for({_name, _kind, meta, _clauses}, protocol, types) do - fallback = if Any in types, do: load_impl(protocol, Any) + fallback = if Any in types, do: __concat__(protocol, Any) line = meta[:line] clauses = @@ -762,7 +764,7 @@ defmodule Protocol do end defp change_struct_impl_for({_name, _kind, meta, _clauses}, protocol, types, structs) do - fallback = if Any in types, do: load_impl(protocol, Any) + fallback = if Any in types, do: __concat__(protocol, Any) clauses = for struct <- structs, do: each_struct_clause_for(struct, protocol, meta) clauses = clauses ++ [fallback_clause_for(fallback, protocol, meta)] @@ -772,7 +774,7 @@ defmodule Protocol do defp built_in_clause_for(mod, guard, protocol, meta, line) do x = {:x, [line: line, version: -1], __MODULE__} guard = quote(line: line, do: :erlang.unquote(guard)(unquote(x))) - body = load_impl(protocol, mod) + body = __concat__(protocol, mod) {meta, [x], [guard], body} end @@ -785,17 +787,13 @@ defmodule Protocol do end defp each_struct_clause_for(struct, protocol, meta) do - {meta, [struct], [], load_impl(protocol, struct)} + {meta, [struct], [], __concat__(protocol, struct)} end defp fallback_clause_for(value, _protocol, meta) do {meta, [quote(do: _)], [], value} end - defp load_impl(protocol, for) do - Module.concat(protocol, for) - end - # Finally compile the module and emit its bytecode. defp compile(definitions, signatures, {module_map, specs, docs_chunk}) do # Protocols in precompiled archives may not have signatures, so we default to an empty map. @@ -957,7 +955,7 @@ defmodule Protocol do # Define the implementation for built-ins :lists.foreach( fn {mod, guard} -> - target = Module.concat(__MODULE__, mod) + target = Protocol.__concat__(__MODULE__, mod) Kernel.def impl_for(data) when :erlang.unquote(guard)(data) do case Code.ensure_compiled(unquote(target)) do @@ -1001,7 +999,7 @@ defmodule Protocol do # Internal handler for Structs Kernel.defp struct_impl_for(struct) do - case Code.ensure_compiled(Module.concat(__MODULE__, struct)) do + case Code.ensure_compiled(Protocol.__concat__(__MODULE__, struct)) do {:module, module} -> module {:error, _} -> unquote(any_impl_for) end @@ -1074,7 +1072,7 @@ defmodule Protocol do quote do protocol = unquote(protocol) for = unquote(for) - name = Module.concat(protocol, for) + name = Protocol.__concat__(protocol, for) Protocol.assert_protocol!(protocol) Protocol.__impl__!(protocol, for, __ENV__) @@ -1120,7 +1118,7 @@ defmodule Protocol do else # TODO: Deprecate this on Elixir v1.22+ assert_impl!(protocol, Any, extra) - {Module.concat(protocol, Any), [for, Macro.struct!(for, env), opts]} + {__concat__(protocol, Any), [for, Macro.struct!(for, env), opts]} end # Clean up variables from eval context @@ -1132,7 +1130,7 @@ defmodule Protocol do else __impl__!(protocol, for, env) assert_impl!(protocol, Any, extra) - impl = Module.concat(protocol, Any) + impl = __concat__(protocol, Any) funs = for {fun, arity} <- protocol.__protocol__(:functions) do @@ -1157,7 +1155,7 @@ defmodule Protocol do def __impl__(:for), do: unquote(for) end - Module.create(Module.concat(protocol, for), [quoted | funs], Macro.Env.location(env)) + Module.create(__concat__(protocol, for), [quoted | funs], Macro.Env.location(env)) end end) end @@ -1204,4 +1202,17 @@ defmodule Protocol do {Reference, :is_reference} ] end + + @doc false + def __concat__(left, right) do + String.to_atom( + ensure_prefix(Atom.to_string(left)) <> "." <> remove_prefix(Atom.to_string(right)) + ) + end + + defp ensure_prefix("Elixir." <> _ = left), do: left + defp ensure_prefix(left), do: "Elixir." <> left + + defp remove_prefix("Elixir." <> right), do: right + defp remove_prefix(right), do: right end diff --git a/lib/elixir/test/elixir/protocol_test.exs b/lib/elixir/test/elixir/protocol_test.exs index e2ed8a4b51..4319d69c61 100644 --- a/lib/elixir/test/elixir/protocol_test.exs +++ b/lib/elixir/test/elixir/protocol_test.exs @@ -110,15 +110,18 @@ defmodule ProtocolTest do assert Sample.impl_for(%ImplStruct{}) == Sample.ProtocolTest.ImplStruct assert Sample.impl_for(%ImplStructExplicitFor{}) == Sample.ProtocolTest.ImplStructExplicitFor assert Sample.impl_for(%NoImplStruct{}) == nil + assert is_nil(Sample.impl_for(%{__struct__: nil})) end test "protocol implementation with Any and struct fallbacks" do assert WithAny.impl_for(%NoImplStruct{}) == WithAny.Any - # Derived - assert WithAny.impl_for(%ImplStruct{}) == ProtocolTest.WithAny.ProtocolTest.ImplStruct + assert WithAny.impl_for(%{__struct__: nil}) == WithAny.Any assert WithAny.impl_for(%{__struct__: "foo"}) == WithAny.Map assert WithAny.impl_for(%{}) == WithAny.Map assert WithAny.impl_for(self()) == WithAny.Any + + # Derived + assert WithAny.impl_for(%ImplStruct{}) == ProtocolTest.WithAny.ProtocolTest.ImplStruct end test "protocol not implemented" do