Skip to content

Commit

Permalink
Merge branch 'refactor/somecleanup' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
giadarol committed Aug 13, 2022
2 parents 6a3f346 + 3f50e8e commit f06a33c
Show file tree
Hide file tree
Showing 10 changed files with 161 additions and 85 deletions.
12 changes: 6 additions & 6 deletions examples/ex_unionref_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,26 @@ class Triangle(xo.Struct):
b = xo.Float64
h = xo.Float64

_extra_c_source = """
_extra_c_sources = ["""
/*gpufun*/
double Triangle_compute_area(Triangle tr, double scale){
double b = Triangle_get_b(tr);
double h = Triangle_get_h(tr);
return 0.5*b*h*scale;
}
"""
"""]


class Square(xo.Struct):
a = xo.Float64

_extra_c_source = """
_extra_c_sources = ["""
/*gpufun*/
double Square_compute_area(Square sq, double scale){
double a = Square_get_a(sq);
return a*a*scale;
}
"""
"""]


class Base(xo.UnionRef):
Expand All @@ -48,15 +48,15 @@ class Prism(xo.Struct):
height = xo.Float64
volume = xo.Float64

_extra_c_source = """
_extra_c_sources = ["""
/*gpukern*/
void Prism_compute_volume(Prism pr){
Base base = Prism_getp_base(pr);
double height = Prism_get_height(pr);
double base_area = Base_compute_area(base, 3.);
Prism_set_volume(pr, base_area*height);
}
"""
"""]


context = xo.ContextCpu()
Expand Down
6 changes: 3 additions & 3 deletions tests/test_capi.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,16 +246,16 @@ def test_dependencies():

class A(xo.Struct):
a=xo.Float64[:]
_extra_c_source="//blah blah A"
_extra_c_sources=["//blah blah A"]

class C(xo.Struct):
c=xo.Float64[:]
_extra_c_source=" //blah blah C"
_extra_c_sources=[" //blah blah C"]

class B(xo.Struct):
b=A
c=xo.Float64[:]
_extra_c_source=" //blah blah B"
_extra_c_sources=[" //blah blah B"]
_depends_on=[C]

assert xo.context.sort_classes([B])[1:]==[A,C,B]
Expand Down
52 changes: 26 additions & 26 deletions tests/test_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class MyStruct2(xo.Struct):

ms2 = MyStruct2(_buffer=ms._buffer, sr=ms, a=[0, 0, 0])

src = """
src = r"""
/*gpukern*/
void cp_sra_to_a(MyStruct2 ms, int64_t n){
Expand Down Expand Up @@ -194,25 +194,25 @@ class Triangle(xo.Struct):
b = xo.Float64
h = xo.Float64

_extra_c_source = """
/*gpufun*/
double Triangle_compute_area(Triangle tr, double scale){
double b = Triangle_get_b(tr);
double h = Triangle_get_h(tr);
return 0.5*b*h*scale;
}
"""
_extra_c_sources = ["""
/*gpufun*/
double Triangle_compute_area(Triangle tr, double scale){
double b = Triangle_get_b(tr);
double h = Triangle_get_h(tr);
return 0.5*b*h*scale;
}
"""]

class Square(xo.Struct):
a = xo.Float64

_extra_c_source = """
/*gpufun*/
double Square_compute_area(Square sq, double scale){
double a = Square_get_a(sq);
return a*a*scale;
}
"""
_extra_c_sources = ["""
/*gpufun*/
double Square_compute_area(Square sq, double scale){
double a = Square_get_a(sq);
return a*a*scale;
}
"""]

class Base(xo.UnionRef):
_reftypes = (Triangle, Square)
Expand All @@ -229,16 +229,16 @@ class Prism(xo.Struct):
height = xo.Float64
volume = xo.Float64

_extra_c_source = """
/*gpukern*/
void Prism_compute_volume(Prism pr){
Base base = Prism_getp_base(pr);
double height = Prism_get_height(pr);
double base_area = Base_compute_area(base, 3.);
printf("base_area = %e", base_area);
Prism_set_volume(pr, base_area*height);
}
"""
_extra_c_sources = ["""
/*gpukern*/
void Prism_compute_volume(Prism pr){
Base base = Prism_getp_base(pr);
double height = Prism_get_height(pr);
double base_area = Base_compute_area(base, 3.);
printf("base_area = %e", base_area);
Prism_set_volume(pr, base_area*height);
}
"""]

for context in xo.context.get_test_contexts():
print(f"Test {context}")
Expand Down
2 changes: 1 addition & 1 deletion xobjects/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,6 @@

from .typeutils import context_default, get_a_buffer

from .hybrid_class import JEncoder, HybridClass, MetaHybridClass
from .hybrid_class import JEncoder, HybridClass, MetaHybridClass, ThisClass

from .linkedarray import BypassLinked
17 changes: 14 additions & 3 deletions xobjects/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def sources_from_classes(classes):
sources = []
for cls in classes:
sources.append(cls._gen_c_api())
if hasattr(cls, "_extra_c_source"):
sources.append(cls._extra_c_source)
if hasattr(cls, "_extra_c_sources"):
sources.extend(cls._extra_c_sources)
return sources


Expand All @@ -93,10 +93,13 @@ def classes_from_kernels(kernels):
return classes


def _concatenate_sources(sources):
def _concatenate_sources(sources, apply_to_source=()):
source = []
folders = set()
for ss in sources:
if isinstance(ss, Source):
ss = ss.source

if hasattr(ss, "read"):
source.append(ss.read())
folders.add(os.path.dirname(ss.name))
Expand All @@ -110,6 +113,9 @@ def _concatenate_sources(sources):

folders = [str(ff) for ff in folders]

for ff in apply_to_source:
source = ff(source)

return source, folders


Expand Down Expand Up @@ -166,6 +172,7 @@ def add_kernels(
sources: list,
kernels: dict,
specialize: bool,
apply_to_source: list,
save_source_as: str,
):
pass
Expand Down Expand Up @@ -411,6 +418,10 @@ def get_classes(self):
classes.append(self.ret.atype)
return classes

class Source:
def __init__(self, source, name=None):
self.source = source
self.name = name

class Method:
def __init__(self, args, c_name, ret):
Expand Down
3 changes: 2 additions & 1 deletion xobjects/context_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def add_kernels(
sources=[],
kernels=[],
specialize=True,
apply_to_source=(),
save_source_as=None,
extra_compile_args=["-O3", "-Wno-unused-function"],
extra_link_args=["-O3"],
Expand Down Expand Up @@ -203,7 +204,7 @@ def add_kernels(

sources = headers + cls_sources + sources

source, folders = _concatenate_sources(sources)
source, folders = _concatenate_sources(sources, apply_to_source)

if specialize:
if self.omp_num_threads > 0:
Expand Down
3 changes: 2 additions & 1 deletion xobjects/context_cupy.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ def add_kernels(
sources=[],
kernels=[],
specialize=True,
apply_to_source=(),
save_source_as=None,
extra_cdef=None,
extra_classes=[],
Expand Down Expand Up @@ -473,7 +474,7 @@ def add_kernels(

sources = headers + cls_sources + sources

source, folders = _concatenate_sources(sources)
source, folders = _concatenate_sources(sources, apply_to_source)
source = "\n".join(['extern "C"{', source, "}"])

if specialize:
Expand Down
3 changes: 2 additions & 1 deletion xobjects/context_pyopencl.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ def add_kernels(
sources=[],
kernels=[],
specialize=True,
apply_to_source=(),
save_source_as=None,
extra_cdef=None,
extra_classes=[],
Expand Down Expand Up @@ -242,7 +243,7 @@ def add_kernels(

sources = headers + cls_sources + sources

source, folders = _concatenate_sources(sources)
source, folders = _concatenate_sources(sources, apply_to_source)

if specialize:
# included files are searched in the same folders od the src_filed
Expand Down
Loading

0 comments on commit f06a33c

Please sign in to comment.