diff --git a/skfem/assembly/basis/abstract_basis.py b/skfem/assembly/basis/abstract_basis.py index 1a1c7962..579ade03 100644 --- a/skfem/assembly/basis/abstract_basis.py +++ b/skfem/assembly/basis/abstract_basis.py @@ -47,7 +47,8 @@ def __init__(self, intorder: Optional[int] = None, quadrature: Optional[Tuple[ndarray, ndarray]] = None, refdom: Type[Refdom] = Refdom, - dofs: Optional[Dofs] = None): + dofs: Optional[Dofs] = None, + disable_doflocs: bool = False): if mesh.refdom != elem.refdom: raise ValueError("Incompatible Mesh and Element.") @@ -56,17 +57,18 @@ def __init__(self, self.dofs = Dofs(mesh, elem) if dofs is None else dofs # global degree-of-freedom location - try: - doflocs = self.mapping.F(elem.doflocs.T) - self.doflocs = np.zeros((doflocs.shape[0], self.N)) - - # match mapped dofs and global dof numbering - for itr in range(doflocs.shape[0]): - for jtr in range(self.dofs.element_dofs.shape[0]): - self.doflocs[itr, self.dofs.element_dofs[jtr]] =\ - doflocs[itr, :, jtr] - except Exception: - logger.warning("Unable to calculate global DOF locations.") + if not disable_doflocs: + try: + doflocs = self.mapping.F(elem.doflocs.T) + self.doflocs = np.zeros((doflocs.shape[0], self.N)) + + # match mapped dofs and global dof numbering + for itr in range(doflocs.shape[0]): + for jtr in range(self.dofs.element_dofs.shape[0]): + self.doflocs[itr, self.dofs.element_dofs[jtr]] =\ + doflocs[itr, :, jtr] + except Exception: + logger.warning("Unable to calculate global DOF locations.") self.mesh = mesh self.elem = elem diff --git a/skfem/assembly/basis/cell_basis.py b/skfem/assembly/basis/cell_basis.py index 4c4f7076..c4a44729 100644 --- a/skfem/assembly/basis/cell_basis.py +++ b/skfem/assembly/basis/cell_basis.py @@ -46,7 +46,8 @@ def __init__(self, intorder: Optional[int] = None, elements: Optional[Any] = None, quadrature: Optional[Tuple[ndarray, ndarray]] = None, - dofs: Optional[Dofs] = None): + dofs: Optional[Dofs] = None, + disable_doflocs: bool = False): """Combine :class:`~skfem.mesh.Mesh` and :class:`~skfem.element.Element` into a set of precomputed global basis functions. @@ -70,6 +71,10 @@ def __init__(self, Optional tuple of quadrature points and weights. dofs Optional :class:`~skfem.assembly.Dofs` object. + disable_doflocs + If `True`, the computation of global DOF locations is + disabled. This may save memory on large meshes if DOF + locations are not required. """ logger.info("Initializing {}({}, {})".format(type(self).__name__, @@ -83,6 +88,7 @@ def __init__(self, quadrature, mesh.refdom, dofs, + disable_doflocs, ) if elements is None: diff --git a/skfem/assembly/basis/facet_basis.py b/skfem/assembly/basis/facet_basis.py index b4f26bae..bd1c0667 100644 --- a/skfem/assembly/basis/facet_basis.py +++ b/skfem/assembly/basis/facet_basis.py @@ -29,7 +29,8 @@ def __init__(self, quadrature: Optional[Tuple[ndarray, ndarray]] = None, facets: Optional[Any] = None, dofs: Optional[Dofs] = None, - side: int = 0): + side: int = 0, + disable_doflocs: bool = False): """Precomputed global basis on boundary facets. Parameters @@ -51,6 +52,10 @@ def __init__(self, Optional subset of facet indices. dofs Optional :class:`~skfem.assembly.Dofs` object. + disable_doflocs + If `True`, the computation of global DOF locations is + disabled. This may save memory on large meshes if DOF + locations are not required. """ typestr = ("{}({}, {})".format(type(self).__name__, @@ -65,6 +70,7 @@ def __init__(self, quadrature, mesh.brefdom, dofs, + disable_doflocs, ) # by default use boundary facets diff --git a/skfem/assembly/basis/interior_facet_basis.py b/skfem/assembly/basis/interior_facet_basis.py index d25a49b0..633481d8 100644 --- a/skfem/assembly/basis/interior_facet_basis.py +++ b/skfem/assembly/basis/interior_facet_basis.py @@ -25,7 +25,8 @@ def __init__(self, quadrature: Optional[Tuple[ndarray, ndarray]] = None, facets: Optional[Any] = None, dofs: Optional[Dofs] = None, - side: int = 0): + side: int = 0, + disable_doflocs: bool = False): """Precomputed global basis on interior facets.""" if facets is None: @@ -42,4 +43,5 @@ def __init__(self, facets=facets, dofs=dofs, side=side, + disable_doflocs=disable_doflocs, ) diff --git a/skfem/assembly/form/coo_data.py b/skfem/assembly/form/coo_data.py index 10565a3d..19add27a 100644 --- a/skfem/assembly/form/coo_data.py +++ b/skfem/assembly/form/coo_data.py @@ -55,7 +55,6 @@ def tolocal(self, basis=None): if self.local_shape is None: raise NotImplementedError("Cannot build local matrices if " "local_shape is not specified.") - assert len(self.local_shape) == 2 local = np.moveaxis(self.data.reshape(self.local_shape + (-1,), order='C'), -1, 0) diff --git a/tests/test_autodiff.py b/tests/test_autodiff.py index 898b3dc9..b8fb299f 100644 --- a/tests/test_autodiff.py +++ b/tests/test_autodiff.py @@ -1,14 +1,17 @@ import pytest import numpy as np -import jax.numpy as jnp from numpy.testing import (assert_array_almost_equal, assert_almost_equal) -from skfem.experimental.autodiff import NonlinearForm -from skfem.experimental.autodiff.helpers import (grad, dot, - ddot, mul, - div, sym_grad, - transpose, - eye, trace) +try: + import jax.numpy as jnp + from skfem.experimental.autodiff import NonlinearForm + from skfem.experimental.autodiff.helpers import (grad, dot, + ddot, mul, + div, sym_grad, + transpose, + eye, trace) +except: + pass from skfem.assembly import Basis from skfem.mesh import MeshTri, MeshQuad from skfem.element import (ElementTriP1, ElementTriP2, diff --git a/tests/test_basis.py b/tests/test_basis.py index 06e1422d..338e0b88 100644 --- a/tests/test_basis.py +++ b/tests/test_basis.py @@ -10,7 +10,8 @@ from skfem import BilinearForm, LinearForm, asm, solve, condense, projection from skfem.mesh import (Mesh, MeshTri, MeshTet, MeshHex, MeshQuad, MeshLine1, MeshWedge1) -from skfem.assembly import CellBasis, FacetBasis, Dofs, Functional +from skfem.assembly import (CellBasis, FacetBasis, Dofs, Functional, + InteriorFacetBasis) from skfem.mapping import MappingIsoparametric from skfem.element import (ElementVectorH1, ElementTriP2, ElementTriP1, ElementTetP2, ElementHexS2, ElementHex2, @@ -646,3 +647,20 @@ def test_with_elements(): assert basis.mapping == basis_half.mapping assert basis.quadrature == basis_half.quadrature assert all(basis_half.tind == basis.mesh.normalize_elements('a')) + + +def test_disable_doflocs(): + mesh = MeshTri().refined(3) + basis = CellBasis(mesh, ElementTriP1()) + basisd = CellBasis(mesh, ElementTriP1(), disable_doflocs=True) + fbasis = FacetBasis(mesh, ElementTriP1()) + fbasisd = FacetBasis(mesh, ElementTriP1(), disable_doflocs=True) + ifbasis = InteriorFacetBasis(mesh, ElementTriP1()) + ifbasisd = InteriorFacetBasis(mesh, ElementTriP1(), + disable_doflocs=True) + assert not hasattr(fbasisd, 'doflocs') + assert hasattr(fbasis, 'doflocs') + assert not hasattr(basisd, 'doflocs') + assert hasattr(basis, 'doflocs') + assert not hasattr(ifbasisd, 'doflocs') + assert hasattr(ifbasis, 'doflocs')