Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
guilhermeleobas committed May 2, 2023
1 parent b4c00a0 commit f027248
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 20 deletions.
17 changes: 12 additions & 5 deletions rbc/heavydb/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,26 @@ def deepcopy(self, context, builder, val, retptr):
ptr_type = self.dtype.members[0]
element_size = int64_t(ptr_type.dtype.bitwidth // 8)

struct_load = builder.load(val, name='struct_load')
src = builder.extract_value(struct_load, 0, name='array_buff_ptr')
element_count = builder.extract_value(struct_load, 1, name='array_size')
is_null = builder.extract_value(struct_load, 2, name='array_is_null')

zero, one, two = int32_t(0), int32_t(1), int32_t(2)
# cgutils.printf(builder, "[deepcopy] val=[%p]\n", val)
src = builder.load(builder.gep(val, [zero, zero]), name='array_buff_ptr')
element_count = builder.load(builder.gep(val, [zero, one]), name='array_size')
is_null = builder.load(builder.gep(val, [zero, two]), name='array_is_null')
# struct_load = builder.load(val, name='struct_load')
# src = builder.extract_value(struct_load, 0, name='array_buff_ptr')
# element_count = builder.extract_value(struct_load, 1, name='array_size')
# is_null = builder.extract_value(struct_load, 2, name='array_is_null')
# cgutils.printf(builder, "[deepcopy] Array=[%p, %d, %d]\n", src, element_count, is_null)

with builder.if_else(cgutils.is_true(builder, is_null)) as (then, otherwise):
with then:
# cgutils.printf(builder, "[deepcopy] Null (is_null=%d) - element_count: %d\n", is_null, element_count)
nullptr = cgutils.get_null_value(src.type)
builder.store(nullptr, builder.gep(retptr, [zero, zero]))
with otherwise:
# we can't just copy the pointer here because return buffers need
# to have their own memory, as input buffers are freed upon returning
# cgutils.printf(builder, "[deepcopy] Not null (null=%d) - element_count: %d\n", is_null, element_count)
dst = memalloc(context, builder, ptr_type, element_count, element_size)
cgutils.raw_memcpy(builder, dst, src, element_count, element_size)
builder.store(dst, builder.gep(retptr, [zero, zero]))
Expand Down
26 changes: 17 additions & 9 deletions rbc/heavydb/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def heavydb_buffer_constructor(context, builder, sig, args):
element_count = builder.zext(args[0], int64_t)
element_size = int64_t(ptr_type.dtype.bitwidth // 8)

# cgutils.printf(builder, "[ctor] element_count: %d - element_size: %d\n", element_count, element_size)
ptr = memalloc(context, builder, ptr_type, element_count, element_size)

llty = context.get_value_type(sig.return_type.dtype)
Expand All @@ -222,6 +223,8 @@ def heavydb_buffer_constructor(context, builder, sig, args):
with orelse:
is_null = context.get_value_type(null_type)(0)
builder.store(is_null, builder.gep(st_ptr, [zero, two]))
# cgutils.printf(builder, "[ctor] Array=[%p, %d, %d]\n", ptr, element_count, is_null)
# cgutils.printf(builder, "[ctor] val=[%p]\n", st_ptr)
return st_ptr


Expand Down Expand Up @@ -296,12 +299,16 @@ def heavydb_buffer_ptr_len_(typingctx, data):
sig = types.int64(data)

def codegen(context, builder, signature, args):
data, = args
rawptr = cgutils.alloca_once_value(builder, value=data)
struct = builder.load(builder.gep(rawptr,
[int32_t(0)]))
return builder.load(builder.gep(
struct, [int32_t(0), int32_t(1)]))
i32 = ir.IntType(32)
zero, one = i32(0), i32(1)
[data] = args
return builder.load(builder.gep(data, [zero, one]))
# data, = args
# rawptr = cgutils.alloca_once_value(builder, value=data)
# struct = builder.load(builder.gep(rawptr,
# [int32_t(0)]))
# return builder.load(builder.gep(
# struct, [int32_t(0), int32_t(1)]))
return sig, codegen


Expand Down Expand Up @@ -371,10 +378,11 @@ def codegen(context, builder, sig, args):

data, index, value = args

rawptr = cgutils.alloca_once_value(builder, value=data)
ptr = builder.load(rawptr)
# breakpoint()
# rawptr = cgutils.alloca_once_value(builder, value=data)
# ptr = builder.load(rawptr)

buf = builder.load(builder.gep(ptr, [zero, zero]))
buf = builder.load(builder.gep(data, [zero, zero]))
# [rbc issue-197] Numba promotes operations like
# int32(a) + int32(b) to int64
fromty = sig.args[2]
Expand Down
8 changes: 4 additions & 4 deletions rbc/irtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,8 +438,8 @@ def compile_to_LLVM(functions_and_signatures,
typing_context = JITRemoteTypingContext()
target_context = JITRemoteTargetContext(typing_context)

nrt_module = create_nrt_functions(target_context, debug=debug)
unicodetype_db = read_unicodetype_db()
# nrt_module = create_nrt_functions(target_context, debug=debug)
# unicodetype_db = read_unicodetype_db()

# Bring over Array overloads (a hack):
target_context._defns = target_desc.target_context._defns
Expand All @@ -455,8 +455,8 @@ def compile_to_LLVM(functions_and_signatures,
assert isinstance(user_defined_llvm_ir, llvm.ModuleRef)
main_module.link_in(user_defined_llvm_ir, preserve=True)

main_module.link_in(unicodetype_db)
main_library.add_ir_module(nrt_module)
# main_module.link_in(unicodetype_db)
# main_library.add_ir_module(nrt_module)

succesful_fids = []
function_names = []
Expand Down
2 changes: 1 addition & 1 deletion rbc/stdlib/creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def _array_api_triu(x, k=0):


@expose.implements('full')
def _impl__full(shape, fill_value, dtype=None):
def _impl_full(shape, fill_value, dtype=None):
"""
Return a new array of given shape and type, filled with fill_value.
"""
Expand Down
18 changes: 17 additions & 1 deletion rbc/tests/heavydb/test_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def test_array_methods(heavydb, method, signature, args, expected):

query = 'select np_{method}'.format(**locals()) + \
'(' + ', '.join(map(str, args)) + ')' + \
' from {heavydb.table_name};'.format(**locals())
' from {heavydb.table_name} limit 1;'.format(**locals())

_, result = heavydb.sql_execute(query)
out = list(result)[0]
Expand Down Expand Up @@ -255,3 +255,19 @@ def fn(arr, fill_value):
_, result = heavydb.sql_execute(f'{query} limit 1;')
# if the execution succeed, it means the return array has type specified by fill_value
assert len(list(result)[0][0]) > 0


def test_foo(heavydb):

from rbc.heavydb import Array
from rbc.externals.stdio import printf

@heavydb('int32[](int32)')
def foo(sz):
printf("\n\n")
a = array_api.ones(sz, dtype='int32')
return a

query = f'select foo(123) from {heavydb.table_name} limit 1'
_, result = heavydb.sql_execute(query)
print(list(result))

0 comments on commit f027248

Please sign in to comment.