Skip to content

Commit

Permalink
Add an Array class
Browse files Browse the repository at this point in the history
Curves (and possibly other future entities) that are defined by a large
number of points create a huge bottleneck when transforming.
Each Point is transformed separately and within that every
transformation matrix must be constructed separately and new and new
numpy arrays are created which takes a lot of time.

Array holds the same numpy array with *all* the points and transforms
everything in one go. Improvement on a gear case: 45s > 17s (the whole
case, not just the transforms).
  • Loading branch information
FranzBangar committed Jul 18, 2024
1 parent 4e26ef7 commit 71cc985
Show file tree
Hide file tree
Showing 10 changed files with 132 additions and 56 deletions.
Empty file removed examples/case/case.foam
Empty file.
4 changes: 4 additions & 0 deletions src/classy_blocks/base/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ class PointCreationError(ShapeCreationError):
pass


class ArrayCreationError(ShapeCreationError):
pass


class AnnulusCreationError(ShapeCreationError):
pass

Expand Down
78 changes: 78 additions & 0 deletions src/classy_blocks/construct/array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from typing import Optional

import numpy as np

from classy_blocks.base.element import ElementBase
from classy_blocks.base.exceptions import ArrayCreationError
from classy_blocks.types import PointListType, PointType, VectorType
from classy_blocks.util import functions as f
from classy_blocks.util.constants import DTYPE

# TODO! Tests


class Array(ElementBase):
def __init__(self, points: PointListType):
"""A list of points ('positions') in 3D space"""
self.points = np.array(points, dtype=DTYPE)

shape = np.shape(self.points)

if shape[1] != 3:
raise ArrayCreationError("Provide a list of points of 3D space!")

if len(self.points) <= 1:
raise ArrayCreationError("Provide at least 2 points in 3D space!")

def translate(self, displacement):
self.points += np.asarray(displacement, dtype=DTYPE)

return self

def rotate(self, angle, axis, origin: Optional[PointType] = None):
if origin is None:
origin = f.vector(0, 0, 0)

axis = np.array(axis)
matrix = f.rotation_matrix(axis, angle)
rotated_points = np.dot(self.points - origin, matrix.T)

self.points = rotated_points + origin

return self

def scale(self, ratio, origin: Optional[PointType] = None):
if origin is None:
origin = f.vector(0, 0, 0)

self.points = origin + (self.points - origin) * ratio

return self

def mirror(self, normal: VectorType, origin: Optional[PointType] = None):
if origin is None:
origin = f.vector(0, 0, 0)

normal = np.array(normal)
matrix = f.mirror_matrix(normal)

self.points -= origin

mirrored_points = np.dot(self.points - origin, matrix.T)
self.points = mirrored_points + origin

return self

@property
def center(self):
return np.average(self.points, axis=0)

@property
def parts(self):
return [self]

def __len__(self):
return len(self.points)

def __getitem__(self, i):
return self.points[i]
28 changes: 10 additions & 18 deletions src/classy_blocks/construct/curves/curve.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import abc
import warnings
from typing import List, Optional, Tuple, Union
from typing import Optional, Tuple, Union

import numpy as np
import scipy.optimize

from classy_blocks.base.element import ElementBase
from classy_blocks.construct.point import Point
from classy_blocks.types import NPPointListType, NPPointType, NPVectorType, ParamCurveFuncType, PointListType, PointType
from classy_blocks.construct.array import Array
from classy_blocks.types import NPPointListType, NPPointType, NPVectorType, ParamCurveFuncType, PointType
from classy_blocks.util import functions as f
from classy_blocks.util.constants import DTYPE, TOL
from classy_blocks.util.constants import TOL


class CurveBase(ElementBase):
Expand Down Expand Up @@ -109,28 +109,20 @@ def get_binormal(self, param: float, delta: float = TOL) -> NPVectorType:
class PointCurveBase(CurveBase):
"""A base object for curves, defined by a list of points"""

array: Array

def _check_param(self, param):
return int(super()._check_param(param))

@staticmethod
def _check_points(points: PointListType) -> List[Point]:
"""Check that provided points are sufficient for a curve"""
points = np.array(points, dtype=DTYPE)
shape = np.shape(points)

if shape[0] < 2:
raise ValueError("Provide at least 2 points that represent a curve")

if shape[1] != 3:
raise ValueError("Provide points in 3D space")

return [Point(p) for p in points]

@property
def center(self):
warnings.warn("Using an approximate default curve center (average)!", stacklevel=2)
return np.average(self.discretize(), axis=0)

@property
def parts(self):
return [self.array]


class FunctionCurveBase(PointCurveBase):
"""A base object for curves, driven by functions"""
Expand Down
11 changes: 6 additions & 5 deletions src/classy_blocks/construct/curves/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np

from classy_blocks.construct.array import Array
from classy_blocks.construct.curves.curve import PointCurveBase
from classy_blocks.types import NPPointListType, NPPointType, PointListType, PointType
from classy_blocks.util import functions as f
Expand All @@ -19,8 +20,8 @@ class DiscreteCurve(PointCurveBase):
Length just sums the distances between points."""

def __init__(self, points: PointListType):
self.points = self._check_points(points)
self.bounds = (0, len(self.points) - 1)
self.array = Array(points)
self.bounds = (0, len(self.array) - 1)

def discretize(
self, param_from: Optional[float] = None, param_to: Optional[float] = None, _count: int = 0
Expand All @@ -32,7 +33,7 @@ def discretize(
param_start = int(min(param_from, param_to))
param_end = int(max(param_from, param_to))

discretized = np.array([p.position for p in self.points[param_start : param_end + 1]])
discretized = self.array[param_start : param_end + 1]

if param_from > param_to:
return np.flip(discretized, axis=0)
Expand All @@ -58,8 +59,8 @@ def center(self):
def get_point(self, param: float) -> NPPointType:
param = self._check_param(param)
index = int(param)
return self.points[index].position
return self.array[index]

@property
def parts(self):
return self.points
return [self.array]
9 changes: 5 additions & 4 deletions src/classy_blocks/construct/curves/interpolated.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np

from classy_blocks.construct.array import Array
from classy_blocks.construct.curves.curve import FunctionCurveBase
from classy_blocks.construct.curves.interpolators import InterpolatorBase, LinearInterpolator, SplineInterpolator
from classy_blocks.types import PointListType
Expand All @@ -25,14 +26,14 @@ class InterpolatedCurveBase(FunctionCurveBase, abc.ABC):
_interpolator: Type[InterpolatorBase]

def __init__(self, points: PointListType):
self.points = self._check_points(points)
self.function = self._interpolator(self.points, False)
self.array = Array(points)
self.function = self._interpolator(self.array, False)
self.bounds = (0, 1)

@property
def segments(self) -> int:
"""Returns number of points this curve was created from"""
return len(self.points) - 1
return len(self.array) - 1

@property
def parts(self):
Expand All @@ -41,7 +42,7 @@ def parts(self):
# is no longer valid and needs to be rebuilt
self.function.invalidate()

return self.points
return [self.array]

def get_length(self, param_from: Optional[float] = None, param_to: Optional[float] = None) -> float:
"""Returns the length of this curve by summing distance between
Expand Down
15 changes: 5 additions & 10 deletions src/classy_blocks/construct/curves/interpolators.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import abc
from typing import List

import numpy as np
import scipy.interpolate
from numpy.typing import NDArray

from classy_blocks.construct.point import Point
from classy_blocks.types import NPPointListType, NPPointType, ParamCurveFuncType
from classy_blocks.construct.array import Array
from classy_blocks.types import NPPointType, ParamCurveFuncType


class InterpolatorBase(abc.ABC):
Expand All @@ -20,7 +19,7 @@ class InterpolatorBase(abc.ABC):
def _get_function(self) -> ParamCurveFuncType:
"""Returns an interpolation function from stored points"""

def __init__(self, points: List[Point], extrapolate: bool):
def __init__(self, points: Array, extrapolate: bool):
self.points = points
self.extrapolate = extrapolate

Expand All @@ -41,10 +40,6 @@ def invalidate(self) -> None:
def params(self) -> NDArray:
return np.linspace(0, 1, num=len(self.points))

@property
def positions(self) -> NPPointListType:
return np.array([point.position for point in self.points])


class LinearInterpolator(InterpolatorBase):
def _get_function(self):
Expand All @@ -56,14 +51,14 @@ def _get_function(self):
fill_value = np.nan

function = scipy.interpolate.interp1d(
self.params, self.positions, bounds_error=bounds_error, fill_value=fill_value, axis=0 # type: ignore
self.params, self.points.points, bounds_error=bounds_error, fill_value=fill_value, axis=0 # type: ignore
)

return lambda param: function(param)


class SplineInterpolator(InterpolatorBase):
def _get_function(self):
spline = scipy.interpolate.make_interp_spline(self.params, self.positions, check_finite=False)
spline = scipy.interpolate.make_interp_spline(self.params, self.points.points, check_finite=False)

return lambda t: spline(t, extrapolate=self.extrapolate)
20 changes: 11 additions & 9 deletions src/classy_blocks/util/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,18 +216,12 @@ def arc_mid(axis: VectorType, center: PointType, radius: float, point_1: PointTy
return center + unit_vector(sec_ort) * radius


def mirror(point: PointType, normal: VectorType, origin: PointType):
"""Mirror a point around a plane, given by a normal and origin"""
# brainlessly copied from https://gamemath.com/book/matrixtransforms.html
point = np.asarray(point)
normal = unit_vector(normal)
origin = np.asarray(origin)

def mirror_matrix(normal: VectorType):
n_x = normal[0]
n_y = normal[1]
n_z = normal[2]

matrix = np.array(
return np.array(
[
[
1 - 2 * n_x**2,
Expand All @@ -247,8 +241,16 @@ def mirror(point: PointType, normal: VectorType, origin: PointType):
]
)


def mirror(point: PointType, normal: VectorType, origin: PointType):
"""Mirror a point around a plane, given by a normal and origin"""
# brainlessly copied from https://gamemath.com/book/matrixtransforms.html
point = np.asarray(point)
normal = unit_vector(normal)
origin = np.asarray(origin)

point -= origin
rotated = point.dot(matrix)
rotated = point.dot(mirror_matrix(normal))
rotated += origin

return rotated
Expand Down
5 changes: 3 additions & 2 deletions tests/test_construct/test_curves/test_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
from parameterized import parameterized

from classy_blocks.base.exceptions import ArrayCreationError
from classy_blocks.construct.curves.discrete import DiscreteCurve


Expand All @@ -22,12 +23,12 @@ def curve(self) -> DiscreteCurve:

def test_single_point(self):
"""Only one point was provided"""
with self.assertRaises(ValueError):
with self.assertRaises(ArrayCreationError):
_ = DiscreteCurve([[0, 0, 0]])

def test_wrong_shape(self):
"""Points are not in 3-dimensions"""
with self.assertRaises(ValueError):
with self.assertRaises(ArrayCreationError):
_ = DiscreteCurve([[0, 0], [1, 0]])

@parameterized.expand(((-1, 1), (0, 5)))
Expand Down
18 changes: 10 additions & 8 deletions tests/test_construct/test_curves/test_interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,22 @@
import numpy as np
from parameterized import parameterized

from classy_blocks.construct.array import Array
from classy_blocks.construct.curves.interpolators import LinearInterpolator
from classy_blocks.construct.point import Point


class LinearInterpolatorTests(unittest.TestCase):
def setUp(self):
# a simple square wave
self.points = [
Point([0, 0, 0]),
Point([0, 1, 0]),
Point([1, 1, 0]),
Point([1, 0, 0]),
Point([2, 0, 0]),
]
self.points = Array(
[
[0, 0, 0],
[0, 1, 0],
[1, 1, 0],
[1, 0, 0],
[2, 0, 0],
]
)

@parameterized.expand(
(
Expand Down

0 comments on commit 71cc985

Please sign in to comment.