Skip to content

Commit

Permalink
Tests...
Browse files Browse the repository at this point in the history
  • Loading branch information
knyazer committed Nov 26, 2023
1 parent 483f252 commit f082c6e
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 38 deletions.
19 changes: 9 additions & 10 deletions cotix/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from jaxtyping import install_import_hook
# from jaxtyping import install_import_hook


__all__ = [
Expand All @@ -12,12 +12,11 @@
"utils",
]

with install_import_hook("cotix", "beartype.beartype"):
import cotix._abstract_shapes as abstract_shapes
import cotix._bodies as bodies
import cotix._collisions as collisions
import cotix._convex_shapes as convex_shapes
import cotix._design_by_contract as design_by_contract
import cotix._geometry_utils as geometry_utils
import cotix._universal_shape as universal_shape
import cotix._utils as utils
import cotix._abstract_shapes as abstract_shapes
import cotix._bodies as bodies
import cotix._collisions as collisions
import cotix._convex_shapes as convex_shapes
import cotix._design_by_contract as design_by_contract
import cotix._geometry_utils as geometry_utils
import cotix._universal_shape as universal_shape
import cotix._utils as utils
18 changes: 12 additions & 6 deletions cotix/_bodies.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,19 +63,25 @@ def set_inertia(self, inertia: Float[Array, ""]):
"""Sets inertia tensor (2d vector) of the body."""
return eqx.tree_at(lambda x: x.inertia, self, inertia)

def set_position(self, position: Float[Array, "2"]):
def set_position(self, position: Float[Array, "2"], update_transform=True):
"""Sets position of the center of mass of the body."""
tmp = eqx.tree_at(lambda x: x.position, self, position)
return tmp.update_transform()
if update_transform:
return tmp.update_transform()
else:
return tmp

def set_velocity(self, velocity: Float[Array, "2"]):
"""Sets velocity of the center of mass of the body."""
return eqx.tree_at(lambda x: x.velocity, self, velocity)

def set_angle(self, angle: Float[Array, ""]):
def set_angle(self, angle: Float[Array, ""], update_transform=True):
"""Sets the angle of rotation around its center of mass."""
tmp = eqx.tree_at(lambda x: x.angle, self, angle)
return tmp.update_transform()
if update_transform:
return tmp.update_transform()
else:
return tmp

def set_angular_velocity(self, angular_velocity: Float[Array, ""]):
"""Sets the rate of change of bodys angle."""
Expand Down Expand Up @@ -105,9 +111,9 @@ def load(self, other):
return (
self.set_mass(other.mass)
.set_inertia(other.inertia)
.set_position(other.position)
.set_position(other.position, False)
.set_velocity(other.velocity)
.set_angle(other.angle)
.set_angle(other.angle, False)
.set_angular_velocity(other.angular_velocity)
.set_elasticity(other.elasticity)
.set_friction_coefficient(other.friction_coefficient)
Expand Down
21 changes: 13 additions & 8 deletions cotix/_colliders.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ def resolve(self, bodies):
continue
for a in body.shape.parts:
for b in body2.shape.parts:
a = a.transform(body.shape._transformer)
b = b.transform(body2.shape._transformer)

type1 = type(a)
type2 = type(b)
# make sure that the order is the same as in _contact_funcs
Expand Down Expand Up @@ -139,15 +142,17 @@ def body1_loop(body1):

# apply every update recorded in contact_points
new_all_contact_points, _ = eqx.internal.scan(
lambda data_arr, i: jtu.tree_map(
lambda leaf, to_set: leaf.at[
current_contacts[0][i], current_contacts[1][i]
].set(to_set[i]),
data_arr,
current_contacts[2],
is_leaf=eqx.is_array,
lambda data_arr, i: (
jtu.tree_map(
lambda leaf, to_set: leaf.at[
current_contacts[0][i], current_contacts[1][i]
].set(to_set[i]),
data_arr,
current_contacts[2],
is_leaf=eqx.is_array,
),
None,
),
None,
init=all_contacts,
xs=jnp.arange(len(current_contacts)),
kind="lax",
Expand Down
20 changes: 20 additions & 0 deletions cotix/_convex_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ def get_center(self):
def move(self, delta: Float[Array, "2"]):
return eqx.tree_at(lambda x: x.position, self, self.position + delta)

def transform(self, transformer):
return Circle(
radius=self.radius,
position=self.position + transformer.shift(),
)


class AABB(AbstractConvexShape, strict=True):
"""
Expand Down Expand Up @@ -88,6 +94,12 @@ def move(self, delta: Float[Array, "2"]):
new_self = eqx.tree_at(lambda x: x.lower, new_self, new_self.lower + delta)
return new_self

def transform(self, transformer):
return AABB(
lower=self.lower + transformer.shift(),
upper=self.upper + transformer.shift(),
)


class Polygon(AbstractConvexShape, strict=True):
"""
Expand Down Expand Up @@ -130,3 +142,11 @@ def contains(self, point):
def move(self, delta: Float[Array, "2"]):
new_vertices = jax.lax.map(lambda x: x + delta, self.vertices)
return Polygon(new_vertices)

def transform(self, transformer):
return Polygon(
vertices=jax.lax.map(
lambda x: x + transformer.shift(),
self.vertices,
),
)
6 changes: 5 additions & 1 deletion cotix/_geometry_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def __init__(self, position=jnp.array([0.0, 0.0]), angle=jnp.array(0.0)):
]
)

self.inv_matrix = jnp.linalg.pinv(self.matrix)
self.inv_matrix = jnp.linalg.inv(self.matrix)

def inverse_direction(self, x):
"""Direction: from global coordinate system to local."""
Expand All @@ -126,3 +126,7 @@ def forward_vector(self, x):
homo_dir = jnp.array([x[0], x[1], 1.0])
transformed = self.matrix @ homo_dir
return jnp.array([transformed[0], transformed[1]]) / transformed[2]

def shift(self):
"""Get the shift that is applied when moving from local to global"""
return self.matrix[2, :2]
8 changes: 0 additions & 8 deletions cotix/_physics_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,6 @@ def _single_body_step(self, body: AbstractBody, dt: Float[Array, ""]):
return new_body

def step(self, bodies, dt=jnp.nan):
dt = eqx.error_if(
dt,
jnp.isnan(dt),
"You must provide dt; if you want to use "
"adaptive step size - don't. If you have no idea what value "
"to put as dt, put 1e-3: probs will be good enough.",
)

new_bodies = jtu.tree_map(
lambda body: self._single_body_step(body, dt=dt),
bodies,
Expand Down
61 changes: 56 additions & 5 deletions test/test_collider.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import equinox as eqx
import jax
from jax import numpy as jnp, random as jr

from cotix._bodies import AnyBody
from cotix._colliders import NaiveCollider
from cotix._convex_shapes import AABB, Circle
from cotix._physics_solvers import ExplicitEulerPhysics
from cotix._universal_shape import UniversalShape


Expand Down Expand Up @@ -40,14 +40,18 @@ def test_simple_world():

collider = NaiveCollider()

eqx.filter_jit(collider.resolve)(bodies)
new_bodies = collider.resolve(bodies)

assert True
# check that positions are sufficiently different
assert jnp.linalg.norm(bodies[0].position - new_bodies[0].position) > 0.1
assert jnp.linalg.norm(bodies[1].position - new_bodies[1].position) > 0.1


def test_a_huge_chunk_of_balls():
# just test compilation speed: if this can be compiled less than a minute,
# good enough ig :)
balls = []
for i in range(40):
for i in range(20):
balls.append(
AnyBody(
position=jnp.zeros((2,)) + 1e-1,
Expand All @@ -64,12 +68,59 @@ def test_a_huge_chunk_of_balls():
)
)

jax.config.update("jax_log_compiles", True)
collider = NaiveCollider()
eqx.filter_jit(collider.resolve)(balls)
assert True


def test_two_ball_long():
# two balls move towards each other, and collide
# let's check that at some point they are moving in the opposite direction

a = AnyBody(
position=jnp.array([-1.86, 0.0]),
velocity=jnp.array([1.0, 0.0]),
elasticity=jnp.array(1.0),
shape=UniversalShape(
Circle(
position=jnp.zeros(
2,
),
radius=jnp.array(1.0),
)
),
)
b = AnyBody(
position=jnp.array([2.784, 0.0]),
velocity=jnp.array([-1.51, 0.0]),
elasticity=jnp.array(1.0),
shape=UniversalShape(
Circle(
position=jnp.zeros(
2,
),
radius=jnp.array(1.0),
)
),
)

bodies = [a, b]
physics_solver = ExplicitEulerPhysics()
collider = NaiveCollider()

for i in range(100):
bodies, _ = eqx.filter_jit(physics_solver.step)(bodies, dt=1e-1)
bodies = eqx.filter_jit(collider.resolve)(bodies)

# check that positions are sufficiently different
assert jnp.linalg.norm(bodies[0].position - bodies[1].position) > 5.0
# check that first ball is moving to the left
assert bodies[0].velocity[0] < -0.8
# check that second ball is moving to the right
assert bodies[1].velocity[0] > 0.8


if __name__ == "__main__":
test_simple_world()
test_a_huge_chunk_of_balls()
test_two_ball_long()
44 changes: 44 additions & 0 deletions test/test_physics_solvers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from jax import numpy as jnp

from cotix._bodies import AnyBody
from cotix._convex_shapes import Circle
from cotix._physics_solvers import ExplicitEulerPhysics
from cotix._universal_shape import UniversalShape


def test_simple_world():
# creates a ball, moving to the right, and checks that it actually is moving
# to the right
a = AnyBody(
position=jnp.zeros((2,)) + 1e-1,
velocity=jnp.array([1.0, 0.0]),
shape=UniversalShape(
Circle(
position=jnp.zeros(
2,
),
radius=jnp.array(1.0),
)
),
)

bodies = [a]

solver = ExplicitEulerPhysics()

for i in range(100):
bodies, _ = solver.step(bodies, dt=1e-1)

assert bodies[0].position[0] > 2e-1
assert (bodies[0].position[1] < 1e-1 + 1e-2) & (bodies[0].position[1] > 1e-1 - 1e-2)
assert (bodies[0].shape.parts[0].position[0] < 1e-2) & (
bodies[0].shape.parts[0].position[0] > -1e-2
)
assert (bodies[0].shape.parts[0].position[1] < 1e-2) & (
bodies[0].shape.parts[0].position[1] > -1e-2
)
assert jnp.linalg.norm(bodies[0].velocity[0] - 1.0) < 1e-2


if __name__ == "__main__":
test_simple_world()

0 comments on commit f082c6e

Please sign in to comment.