diff --git a/src/metal.jl b/src/metal.jl index cd8c6f3f..79f9df6e 100644 --- a/src/metal.jl +++ b/src/metal.jl @@ -162,7 +162,8 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L # add kernel metadata if job.config.kernel - entry = add_address_spaces!(job, mod, entry) + entry = add_parameter_address_spaces!(job, mod, entry) + entry = add_global_address_spaces!(job, mod, entry) add_argument_metadata!(job, mod, entry) @@ -199,10 +200,12 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L end # perform codegen passes that would normally run during machine code emission - # XXX: codegen passes don't seem available in the new pass manager yet - @dispose pm=ModulePassManager() begin - expand_reductions!(pm) - run!(pm, mod) + if LLVM.has_oldpm() + # XXX: codegen passes don't seem available in the new pass manager yet + @dispose pm=ModulePassManager() begin + expand_reductions!(pm) + run!(pm, mod) + end end return functions(mod)[entry_fn] @@ -226,7 +229,8 @@ end # NOTE: this pass also only rewrites pointers _without_ address spaces, which requires it to # be executed after optimization (where Julia's address spaces are stripped). If we ever # want to execute it earlier, adapt remapType to rewrite all pointer types. -function add_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.Function) +function add_parameter_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module, + f::LLVM.Function) ft = function_type(f) # find the byref parameters @@ -332,6 +336,92 @@ function add_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module, return new_f end +# update address spaces of constant global objects +# +# global constant objects need to reside in address space 2, so we clone each function +# that uses global objects and rewrite the globals used by it +function add_global_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module, + entry::LLVM.Function) + # determine global variables we need to update + global_map = Dict{LLVM.Value, LLVM.Value}() + for gv in globals(mod) + isconstant(gv) || continue + addrspace(value_type(gv)) == 0 || continue + + gv_ty = global_value_type(gv) + gv_name = LLVM.name(gv) + + LLVM.name!(gv, gv_name * ".old") + new_gv = GlobalVariable(mod, gv_ty, gv_name, 2) + + alignment!(new_gv, alignment(gv)) + unnamed_addr!(new_gv, unnamed_addr(gv)) + initializer!(new_gv, initializer(gv)) + constant!(new_gv, true) + linkage!(new_gv, linkage(gv)) + visibility!(new_gv, visibility(gv)) + + # we can't map the global variable directly, as the type change won't be applied + # recursively. so instead map a constant expression converting the value of the + # global into one with the old address space, avoiding a type change. + ptr = const_addrspacecast(new_gv, value_type(gv)) + + global_map[gv] = ptr + end + isempty(global_map) && return entry + + # determine which functions we need to update + function_worklist = Set{LLVM.Function}() + function check_user(val) + if val isa LLVM.Instruction + bb = LLVM.parent(val) + f = LLVM.parent(bb) + + push!(function_worklist, f) + elseif val isa LLVM.ConstantExpr + for use in uses(val) + check_user(user(use)) + end + end + end + for gv in keys(global_map), use in uses(gv) + check_user(user(use)) + end + + # update functions that use the global + if !isempty(function_worklist) + entry_fn = LLVM.name(entry) + for fun in function_worklist + fn = LLVM.name(fun) + + new_fun = clone(fun; value_map=global_map) + replace_uses!(fun, new_fun) + replace_metadata_uses!(fun, new_fun) + erase!(fun) + + LLVM.name!(new_fun, fn) + end + entry = LLVM.functions(mod)[entry_fn] + end + + # delete old globals + for (old, new) in global_map + for use in uses(old) + val = user(use) + if val isa ConstantExpr + # XXX: shouldn't clone_into! remove unused CEs? + isempty(uses(val)) || error("old function still has uses (via a constant expr)") + LLVM.unsafe_destroy!(val) + end + end + @assert isempty(uses(old)) + replace_metadata_uses!(old, new) + erase!(old) + end + + return entry +end + # value-to-reference conversion # diff --git a/test/metal_tests.jl b/test/metal_tests.jl index de97d90b..113868aa 100644 --- a/test/metal_tests.jl +++ b/test/metal_tests.jl @@ -89,7 +89,7 @@ end declare void @llvm.va_start(i8*) declare void @llvm.va_end(i8*) declare void @air.os_log(i8*, i64) - + define void @metal_os_log(...) { %1 = alloca i8* %2 = bitcast i8** %1 to i8* @@ -126,6 +126,22 @@ end end end +@testset "constant globals" begin + mod = @eval module $(gensym()) + const xs = (1.0f0, 2f0) + + function kernel(ptr, i) + unsafe_store!(ptr, xs[i]) + + return + end + end + + ir = sprint(io->Metal.code_llvm(io, mod.kernel, Tuple{Core.LLVMPtr{Float32,1}, Int}; + dump_module=true, kernel=true)) + @test occursin("addrspace(2) constant [2 x float]", ir) +end + end end