Skip to content

Commit

Permalink
Merge branch 'feature/check_if_owndata' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
giadarol committed Feb 27, 2022
2 parents a75bd1e + 8f6d524 commit fa80e07
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 1 deletion.
42 changes: 42 additions & 0 deletions tests/test_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,3 +255,45 @@ class Prism(xo.Struct):

assert prism_triangle.volume == 45
assert prism_square.volume == 120

assert prism_triangle._has_refs

def test_has_refs():

class StructWRef(xo.Struct):
a = xo.Ref(xo.Float64[:])
assert StructWRef._has_refs

class StructNoRef(xo.Struct):
a = xo.Float64[:]
assert not StructNoRef._has_refs

class NestedWRef(xo.Struct):
s = StructWRef
assert NestedWRef._has_refs

class NestedNoRef(xo.Struct):
s = StructNoRef
assert not NestedNoRef._has_refs

ArrNoRef = xo.Float64[:]
assert not ArrNoRef._has_refs

ArrWRef = xo.Ref(xo.Float64)[:]
assert ArrWRef._has_refs

class StructArrRef(xo.Struct):
arr = ArrWRef
assert StructArrRef._has_refs

class StructArrNoRef(xo.Struct):
arr = ArrNoRef
assert not StructArrNoRef._has_refs

ArrOfStructRef = NestedWRef[:]
assert ArrOfStructRef._has_refs

class MyUnion(xo.UnionRef):
_ref = [xo.Float64, xo.Int32]
assert MyUnion._has_refs

8 changes: 8 additions & 0 deletions xobjects/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,14 @@ def __new__(cls, name, bases, data):
if "_c_type" not in data:
data["_c_type"] = name

# determine has_refs
if '_itemtype' in data.keys():
if (hasattr(data['_itemtype'], '_has_refs')
and data['_itemtype']._has_refs):
data['_has_refs'] = True
else:
data['_has_refs'] = False

return type.__new__(cls, name, bases, data)

def __getitem__(cls, shape):
Expand Down
5 changes: 5 additions & 0 deletions xobjects/ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ def __getitem__(cls, reftype):


class Ref(metaclass=MetaRef):

_has_refs = True

def __init__(self, reftype):
self._reftype = reftype
self.__name__ = "Ref" + self._reftype.__name__
Expand Down Expand Up @@ -102,6 +105,8 @@ def __new__(cls, name, bases, data):
if "_methods" not in data:
data["_methods"] = []

data['_has_refs'] = True

return type.__new__(cls, name, bases, data)

def _is_member(cls, value):
Expand Down
10 changes: 9 additions & 1 deletion xobjects/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
import logging
from typing import Callable, Optional


from .typeutils import (
get_a_buffer,
dispatch_arg,
Expand Down Expand Up @@ -246,6 +245,15 @@ def _inspect_args(cls, *args, **kwargs):
if "_c_type" not in data:
data["_c_type"] = name

# determine owndata
_has_refs = False
for ff in data['_fields']:
ftype = ff.ftype
if hasattr(ftype, '_has_refs') and ftype._has_refs:
_has_refs = True
break
data['_has_refs'] = _has_refs

return type.__new__(cls, name, bases, data)

def __getitem__(cls, shape):
Expand Down

0 comments on commit fa80e07

Please sign in to comment.