Skip to content

Commit

Permalink
Do not discard nil on protocol concat, closes #14311 (#14314)
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim committed Mar 6, 2025
1 parent 75677b9 commit 178643f
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 11 deletions.
6 changes: 5 additions & 1 deletion lib/elixir/lib/module.ex
Original file line number Diff line number Diff line change
Expand Up @@ -946,7 +946,8 @@ defmodule Module do
@doc """
Concatenates two aliases and returns a new alias.
It handles binaries and atoms.
It handles binaries and atoms. If one of the aliases
is nil, it is discarded.
## Examples
Expand All @@ -956,6 +957,9 @@ defmodule Module do
iex> Module.concat(Foo, "Bar")
Foo.Bar
iex> Module.concat(Foo, nil)
Foo
"""
@spec concat(binary | atom, binary | atom) :: atom
def concat(left, right)
Expand Down
33 changes: 25 additions & 8 deletions lib/elixir/lib/protocol.ex
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ defmodule Protocol do
end

defp assert_impl!(protocol, base, extra) do
impl = Module.concat(protocol, base)
impl = Protocol.__concat__(protocol, base)

try do
Code.ensure_compiled!(impl)
Expand Down Expand Up @@ -678,7 +678,7 @@ defmodule Protocol do
end

defp load_impl(protocol, for) do
Module.concat(protocol, for)
Protocol.__concat__(protocol, for)
end

# Finally compile the module and emit its bytecode.
Expand Down Expand Up @@ -831,7 +831,7 @@ defmodule Protocol do
# Define the implementation for built-ins
:lists.foreach(
fn {guard, mod} ->
target = Module.concat(__MODULE__, mod)
target = Protocol.__concat__(__MODULE__, mod)

Kernel.def impl_for(data) when :erlang.unquote(guard)(data) do
case Code.ensure_compiled(unquote(target)) do
Expand Down Expand Up @@ -875,7 +875,7 @@ defmodule Protocol do

# Internal handler for Structs
Kernel.defp struct_impl_for(struct) do
case Code.ensure_compiled(Module.concat(__MODULE__, struct)) do
case Code.ensure_compiled(Protocol.__concat__(__MODULE__, struct)) do
{:module, module} -> module
{:error, _} -> unquote(any_impl_for)
end
Expand Down Expand Up @@ -948,7 +948,7 @@ defmodule Protocol do
quote do
protocol = unquote(protocol)
for = unquote(for)
name = Module.concat(protocol, for)
name = Protocol.__concat__(protocol, for)

Protocol.assert_protocol!(protocol)
Protocol.__ensure_defimpl__(protocol, for, __ENV__)
Expand Down Expand Up @@ -994,7 +994,7 @@ defmodule Protocol do
else
# TODO: Deprecate this on Elixir v1.22+
assert_impl!(protocol, Any, extra)
{Module.concat(protocol, Any), [for, Macro.struct!(for, env), opts]}
{Protocol.__concat__(protocol, Any), [for, Macro.struct!(for, env), opts]}
end

# Clean up variables from eval context
Expand All @@ -1006,7 +1006,7 @@ defmodule Protocol do
else
__ensure_defimpl__(protocol, for, env)
assert_impl!(protocol, Any, extra)
impl = Module.concat(protocol, Any)
impl = Protocol.__concat__(protocol, Any)

funs =
for {fun, arity} <- protocol.__protocol__(:functions) do
Expand All @@ -1031,7 +1031,11 @@ defmodule Protocol do
def __impl__(:for), do: unquote(for)
end

Module.create(Module.concat(protocol, for), [quoted | funs], Macro.Env.location(env))
Module.create(
Protocol.__concat__(protocol, for),
[quoted | funs],
Macro.Env.location(env)
)
end
end)
end
Expand Down Expand Up @@ -1070,4 +1074,17 @@ defmodule Protocol do
is_reference: Reference
]
end

@doc false
def __concat__(left, right) do
String.to_atom(
ensure_prefix(Atom.to_string(left)) <> "." <> remove_prefix(Atom.to_string(right))
)
end

defp ensure_prefix("Elixir." <> _ = left), do: left
defp ensure_prefix(left), do: "Elixir." <> left

defp remove_prefix("Elixir." <> right), do: right
defp remove_prefix(right), do: right
end
7 changes: 5 additions & 2 deletions lib/elixir/test/elixir/protocol_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,18 @@ defmodule ProtocolTest do
assert Sample.impl_for(%ImplStruct{}) == Sample.ProtocolTest.ImplStruct
assert Sample.impl_for(%ImplStructExplicitFor{}) == Sample.ProtocolTest.ImplStructExplicitFor
assert Sample.impl_for(%NoImplStruct{}) == nil
assert is_nil(Sample.impl_for(%{__struct__: nil}))
end

test "protocol implementation with Any and struct fallbacks" do
assert WithAny.impl_for(%NoImplStruct{}) == WithAny.Any
# Derived
assert WithAny.impl_for(%ImplStruct{}) == ProtocolTest.WithAny.ProtocolTest.ImplStruct
assert WithAny.impl_for(%{__struct__: nil}) == WithAny.Any
assert WithAny.impl_for(%{__struct__: "foo"}) == WithAny.Map
assert WithAny.impl_for(%{}) == WithAny.Map
assert WithAny.impl_for(self()) == WithAny.Any

# Derived
assert WithAny.impl_for(%ImplStruct{}) == ProtocolTest.WithAny.ProtocolTest.ImplStruct
end

test "protocol not implemented" do
Expand Down

0 comments on commit 178643f

Please sign in to comment.