Skip to content

Commit

Permalink
[Metal] Add correct addrspace to global constants (#648)
Browse files Browse the repository at this point in the history
Co-authored-by: Tim Besard <[email protected]>
  • Loading branch information
tgymnich and maleadt authored Nov 29, 2024
1 parent d77b429 commit a849e8a
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 7 deletions.
102 changes: 96 additions & 6 deletions src/metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L

# add kernel metadata
if job.config.kernel
entry = add_address_spaces!(job, mod, entry)
entry = add_parameter_address_spaces!(job, mod, entry)
entry = add_global_address_spaces!(job, mod, entry)

add_argument_metadata!(job, mod, entry)

Expand Down Expand Up @@ -199,10 +200,12 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L
end

# perform codegen passes that would normally run during machine code emission
# XXX: codegen passes don't seem available in the new pass manager yet
@dispose pm=ModulePassManager() begin
expand_reductions!(pm)
run!(pm, mod)
if LLVM.has_oldpm()
# XXX: codegen passes don't seem available in the new pass manager yet
@dispose pm=ModulePassManager() begin
expand_reductions!(pm)
run!(pm, mod)
end
end

return functions(mod)[entry_fn]
Expand All @@ -226,7 +229,8 @@ end
# NOTE: this pass also only rewrites pointers _without_ address spaces, which requires it to
# be executed after optimization (where Julia's address spaces are stripped). If we ever
# want to execute it earlier, adapt remapType to rewrite all pointer types.
function add_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.Function)
function add_parameter_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
f::LLVM.Function)
ft = function_type(f)

# find the byref parameters
Expand Down Expand Up @@ -332,6 +336,92 @@ function add_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
return new_f
end

# update address spaces of constant global objects
#
# global constant objects need to reside in address space 2, so we clone each function
# that uses global objects and rewrite the globals used by it.
#
# Arguments:
# - `job`:   the compiler job (not consulted by this pass; kept for a signature that is
#            parallel to the other `finish_ir!` helpers)
# - `mod`:   the module whose constant globals are rewritten in place
# - `entry`: the kernel entry function
#
# Returns the entry function. When functions had to be cloned, it is looked up again by
# name, since the original `entry` handle may have been erased.
#
# Throws an error if a constant expression that refers to one of the old globals is still
# in use after all affected functions have been cloned.
function add_global_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
                                    entry::LLVM.Function)
    # determine global variables we need to update
    global_map = Dict{LLVM.Value, LLVM.Value}()
    for gv in globals(mod)
        # only constant globals that still live in the default address space qualify;
        # this also skips the replacement globals created below (address space 2)
        isconstant(gv) || continue
        addrspace(value_type(gv)) == 0 || continue

        gv_ty = global_value_type(gv)
        gv_name = LLVM.name(gv)

        # free up the original name so that the replacement can take it over
        LLVM.name!(gv, gv_name * ".old")
        new_gv = GlobalVariable(mod, gv_ty, gv_name, 2)

        # carry over all properties of the original global
        alignment!(new_gv, alignment(gv))
        unnamed_addr!(new_gv, unnamed_addr(gv))
        initializer!(new_gv, initializer(gv))
        constant!(new_gv, true)
        linkage!(new_gv, linkage(gv))
        visibility!(new_gv, visibility(gv))

        # we can't map the global variable directly, as the type change won't be applied
        # recursively. so instead map a constant expression converting the value of the
        # global into one with the old address space, avoiding a type change.
        ptr = const_addrspacecast(new_gv, value_type(gv))

        global_map[gv] = ptr
    end
    isempty(global_map) && return entry

    # determine which functions we need to update
    function_worklist = Set{LLVM.Function}()
    function check_user(val)
        if val isa LLVM.Instruction
            # instruction -> basic block -> function
            bb = LLVM.parent(val)
            f = LLVM.parent(bb)

            push!(function_worklist, f)
        elseif val isa LLVM.ConstantExpr
            # constant expressions are not attached to a function themselves, so look
            # through them to find the instructions that (transitively) use the global
            for use in uses(val)
                check_user(user(use))
            end
        end
    end
    for gv in keys(global_map), use in uses(gv)
        check_user(user(use))
    end

    # update functions that use the global
    if !isempty(function_worklist)
        # remember the entry by name; its handle is stale if it gets cloned and erased
        entry_fn = LLVM.name(entry)
        for fun in function_worklist
            fn = LLVM.name(fun)

            # clone with the old globals remapped to the address-space-cast replacements
            new_fun = clone(fun; value_map=global_map)
            replace_uses!(fun, new_fun)
            replace_metadata_uses!(fun, new_fun)
            erase!(fun)

            # the old function is gone, so the clone can take over its name
            LLVM.name!(new_fun, fn)
        end
        entry = LLVM.functions(mod)[entry_fn]
    end

    # delete old globals
    for (old, new) in global_map
        for use in uses(old)
            val = user(use)
            if val isa ConstantExpr
                # XXX: shouldn't clone_into! remove unused CEs?
                isempty(uses(val)) ||
                    error("old global still has uses (via a constant expr)")
                LLVM.unsafe_destroy!(val)
            end
        end
        # internal invariant: every remaining user was rewritten or destroyed above
        @assert isempty(uses(old))
        replace_metadata_uses!(old, new)
        erase!(old)
    end

    return entry
end


# value-to-reference conversion
#
Expand Down
18 changes: 17 additions & 1 deletion test/metal_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ end
declare void @llvm.va_start(i8*)
declare void @llvm.va_end(i8*)
declare void @air.os_log(i8*, i64)
define void @metal_os_log(...) {
%1 = alloca i8*
%2 = bitcast i8** %1 to i8*
Expand Down Expand Up @@ -126,6 +126,22 @@ end
end
end

@testset "constant globals" begin
    # define the kernel in a freshly generated module, next to the constant tuple it reads
    mod = @eval module $(gensym())
        const xs = (1f0, 2f0)

        function kernel(ptr, i)
            val = xs[i]
            unsafe_store!(ptr, val)
            return nothing
        end
    end

    # dump the whole module so that global definitions are part of the output
    argtypes = Tuple{Core.LLVMPtr{Float32,1}, Int}
    ir = sprint() do io
        Metal.code_llvm(io, mod.kernel, argtypes; dump_module=true, kernel=true)
    end

    # the tuple should have been emitted as a constant global in address space 2
    @test occursin("addrspace(2) constant [2 x float]", ir)
end

end

end
Expand Down

0 comments on commit a849e8a

Please sign in to comment.