Skip to content

Commit

Permalink
Purum purum, even more work
Browse files Browse the repository at this point in the history
  • Loading branch information
Roman Knyazhitskiy committed Nov 8, 2023
1 parent 0de6235 commit e2017c9
Show file tree
Hide file tree
Showing 12 changed files with 706 additions and 347 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,8 @@
**.vscode
**.all_objects.cache
.work-distro
**/.pytest_cache
**/.ruff_cache
**/.ipynb_checkpoints
venv
**/.idea
2 changes: 2 additions & 0 deletions cotix/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"convex_shapes",
"design_by_contract",
"universal_shape",
"utils",
]

with install_import_hook("cotix", "beartype.beartype"):
Expand All @@ -19,3 +20,4 @@
import cotix._design_by_contract as design_by_contract
import cotix._geometry_utils as geometry_utils
import cotix._universal_shape as universal_shape
import cotix._utils as utils
3 changes: 3 additions & 0 deletions cotix/_bodies.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ def collides_with(self, other):
"""Returns boolean, whether there is a collision with another body."""
return self.shape.collides_with(other.shape)

def penetrates_with(self, other):
return self.shape.penetrates_with(other.shape)

def possibly_collides_with(self, other):
return self.shape.possibly_collides_with(other.shape)

Expand Down
128 changes: 71 additions & 57 deletions cotix/_colliders.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

import equinox as eqx
import jax
from jax import numpy as jnp, tree_util as jtu
from jax import numpy as jnp
from jaxtyping import Array, Float, Int

from ._bodies import AbstractBody
from ._collision_resolution import resolve_collision
from ._convex_shapes import AABB
from ._utils import filter_scan
from ._utils import JList, make_pairs


class AbstractCollider(eqx.Module):
Expand Down Expand Up @@ -47,11 +48,24 @@ class _CollisionWithPenetration(eqx.Module):
penetration_vector: Float[Array, "2"]


class _PostCollisionUpdate(eqx.Module):
i: Int[Array, ""]
j: Int[Array, ""]
body_i: AbstractBody
body_j: AbstractBody
@jax.jit
def _check_aabb_collision(a, b):
c = (a[0] < b[0]) & AABB.collides(a[1], b[1])
return jax.lax.cond(
c,
lambda: _BroadCollision(jnp.array(a[0]), jnp.array(b[0])),
lambda: _BroadCollision(jnp.array(-1), jnp.array(-1)),
)


@eqx.filter_jit
def resolve_idk(x, y):
jax.debug.print("{x};{y}", x=x, y=y)


@eqx.filter_jit
def get_body(bodies, i):
return bodies[int(i)]


class NaiveCollider(AbstractCollider):
Expand All @@ -62,57 +76,57 @@ class NaiveCollider(AbstractCollider):
then we have an exact phase, where we detect just N collisions.
"""

def broad_phase(self, bodies, N: int):
res = [_BroadCollision(jnp.array(-1), jnp.array(-1))] * N

def loop_body(carry, xs):
collision_index, res_index, res = carry
i = collision_index // len(bodies)
j = jnp.mod(collision_index, len(bodies))
jax.debug.print("{x}", x=(i, j))
res, res_index = jax.lax.cond(
AABB.collide(AABB(bodies[i]), AABB(bodies[j])),
(
eqx.tree_at(
lambda r: r[res_index], res, replace=_BroadCollision(i, j)
),
res_index + 1,
),
(res[res_index], res_index),
)
return collision_index + 1, res_index

(_, _, res), _ = filter_scan(
loop_body, (jnp.array(0), jnp.array(0), res), None, length=len(bodies) ** 2
@eqx.filter_jit
def broad_phase(self, bodies, limit: int):
# map bodies to theirs aabbs, trace-time
aabbs = [AABB()] * len(bodies)
for i in range(len(bodies)):
aabbs[i] = AABB.of_universal(bodies[i].shape)
out = make_pairs(
aabbs,
_check_aabb_collision,
_BroadCollision(jnp.array(-1), jnp.array(-1)),
limit=limit,
)
return res[:N]

def exact_phase(self, bodies, collision_data, N: int):
pass
return out

def resolve_penetration(self, collision_to_resolve, bodies):
def resolve_collision_of_two_bodies(self, *args):
raise NotImplementedError

i, j = collision_to_resolve.i, collision_to_resolve.j
penetration_vector = collision_to_resolve.penetration_vector
new_body_a, new_body_b = resolve_collision_of_two_bodies(
bodies[i], bodies[j], penetration_vector
)
return _PostCollisionUpdate(i, j, new_body_a, new_body_b)

def resolve(self, bodies: List[AbstractBody]):
broad_collisions = self.broad_phase(bodies, N=4 * len(bodies))
exact_collisions = self._exact_phase(bodies, broad_collisions, N=len(bodies))

updates = jtu.tree_map(
lambda x: self.resolve_penetration(x, bodies),
exact_collisions,
is_leaf=isinstance(AbstractBody),
)
@eqx.filter_jit
def total_phase(self, bodies, limit: int):
@jax.jit
def _check_actual_collision(a, b):
c = a[0] < b[0]
return jax.lax.cond(
c, lambda: a[1].penetrates_with(b[1]), lambda: jnp.zeros((2,))
)

# now 'apply' updates, by setting correct bodies
for upd in updates:
bodies = bodies.at[upd.i].set(upd.body_i)
bodies = bodies.at[upd.j].set(upd.body_j)
return bodies
out = make_pairs(bodies, _check_actual_collision, jnp.zeros((2,)), limit=limit)

return out

def resolve(self, bodies):
initial_bodies = bodies
length = len(bodies)
bodies = JList(bodies)

broad_collisions = self.broad_phase(bodies, N=4 * length)
penetrations = self.narrow_phase(bodies, broad_collisions, N=length)

# yep, the reason we do a for-loop here and not jax tree map
# is so that XLA maybe optimizes all memory shit inside it
# cuz when we do jtu tree map, I think XLA does not need
# extra-iteration optimizations, like it optimizes only inside the iteration
for p in penetrations:
new_body_a, new_body_b = resolve_collision(p.a, p.b, p.penetration_vector)
# TODO: this cool piece of code might cause ~N copies of all the bodies,
# which is like a loooot. Idk how to fix it though, so for now it is fine
# btw, i am not sure if it actually causes them: maybe XLA is smart enough
# to not copy bodies for every instruction (probably it is smart enough)
bodies = bodies.set_at(p.i, new_body_a)
bodies = bodies.set_at(p.j, new_body_b)

out = bodies.to_pytree()
# easy way to check if out and bodies are the same,
# lol -> if the shapes are not the same,
# JAX will commit suicide by itself
return jax.lax.cond(True, lambda: out, lambda: initial_bodies)
8 changes: 7 additions & 1 deletion cotix/_collision_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
Implements physics logic that resolves a simple elastic collision between two bodies.
"""


from typing import Tuple

import jax.lax
Expand Down Expand Up @@ -78,3 +77,10 @@ def _1d_elastic_collision_velocities(m1, m2, u1, u2):
v1 = ((m1 - m2) / (m1 + m2)) * u1 + ((2 * m2) / (m1 + m2)) * u2
v2 = ((2 * m1) / (m1 + m2)) * u1 + ((m2 - m1) / (m1 + m2)) * u2
return v1, v2


def resolve_collision(
body1: AbstractBody, body2: AbstractBody, penetration_vector: Float[Array, "2"]
):
"""Cute user-facing abstraction"""
return _resolve_collision_checked(body1, body2, penetration_vector)
146 changes: 146 additions & 0 deletions cotix/_contacts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float

from ._convex_shapes import AABB, Circle


class ContactInfo(eqx.Module):
penetration_vector: Float[Array, "2"]
contact_point: Float[Array, "2"]

def __init__(self, penetration_vector, contact_point):
self.penetration_vector = penetration_vector
self.contact_point = contact_point

@staticmethod
def nan():
return ContactInfo(jnp.zeros((2,)), jnp.array([jnp.nan, jnp.nan]))


def circle_vs_circle(a: Circle, b: Circle):
delta = a.position - b.position
distance = jnp.linalg.norm(delta)
direction_between_shapes = jax.lax.cond(
distance == 0.0, lambda: jnp.array([1.0, 0.0]), lambda: delta / distance
)
penetration_vector = direction_between_shapes * jnp.minimum(
distance - (a.radius + b.radius), 0.0
)
contact_point = (
b.position + direction_between_shapes * (b.radius - a.radius) + a.position
) / 2.0
# check that centers lie from different sides of contact point
# and if not, return the center that lies inside another circle
contact_point = jax.lax.cond(
jnp.dot(a.position - contact_point, b.position - contact_point) <= 0,
lambda: contact_point, # different sides
lambda: jax.lax.cond(
a.contains(b.position), # same side
lambda: b.position, # b's center contained in a
lambda: a.position,
), # otherwise
)

return jax.lax.cond(
distance <= a.radius + b.radius,
lambda: ContactInfo(-penetration_vector, contact_point),
lambda: ContactInfo.nan(),
)


def aabb_vs_aabb(a: AABB, b: AABB, eps=1e-8):
is_first_below_second = a.upper[1] <= b.lower[1]
is_first_above_second = a.lower[1] >= b.upper[1]
is_first_left_second = a.upper[0] <= b.lower[0]
is_first_right_second = a.lower[0] >= b.upper[0]

def estimate_contact():
depths = jnp.array(
[
jnp.maximum(
a.upper[1] - b.lower[1], -eps
), # eps here, so that 0 processed correctly
jnp.maximum(b.upper[1] - a.lower[1], -eps),
jnp.maximum(a.upper[0] - b.lower[0], -eps),
jnp.maximum(b.upper[0] - a.lower[0], -eps),
]
)
dirs = jnp.array([[0, -1], [0, 1], [-1, 0], [1, 0]])

index = jnp.argmin(depths)
min_depth = jnp.clip(depths[index], a_min=0.0)
penetration_vector = min_depth * dirs[index]
min_upper = jnp.minimum(a.upper, b.upper)
max_lower = jnp.maximum(a.lower, b.lower)
return ContactInfo(penetration_vector, (min_upper + max_lower) / 2.0)

return jax.lax.cond(
~(
is_first_below_second
| is_first_left_second
| is_first_above_second
| is_first_right_second
),
lambda: estimate_contact(),
lambda: ContactInfo.nan(),
)


def circle_vs_aabb(a: Circle, b: AABB, eps=1e-6):
disp = a.get_center() - b.get_center()
clamp_disp = jnp.clip(disp, b.lower - b.get_center(), b.upper - b.get_center())
ccp = (
b.get_center() + clamp_disp
) # ccp = closest circle point, point on aabb that is closest to the circle
ccp = eqx.error_if(
ccp, ~b.contains(ccp), "Gm, closest point in the AABB is not in AABB. wut?"
)

vs = jnp.array(
[
b.lower,
jnp.array([b.lower[0], b.upper[1]]),
b.upper,
jnp.array([b.upper[0], b.lower[1]]),
]
)

perfect_vertex = jnp.any(jnp.linalg.norm(vs - ccp, axis=1) < eps)

def circle_dir_move():
# move the aabb out of the circle, in the direction
# that is not aligned with axes
dir = ccp - a.position
dir_norm = dir / jnp.linalg.norm(dir) # TODO: check division by zero
return ContactInfo(-(a.position + a.radius * dir_norm - ccp), ccp)

def aligned_move():
# now, we want to move aabb out of the circle moving only in one axis
# while there is a hard way to do this, I am too lazy,
# so I am just gonna try all 4 variants and see which one is the best
dirs = jnp.array([[0, 1], [0, -1], [1, 0], [-1, 0]])
a.position - ccp
d_from_aabb = b.get_center() - ccp

prods = jax.lax.map(lambda x: jnp.dot(x, d_from_aabb), dirs)
jnp.argmax(prods)

shifts = jnp.array(
[
a.position[1] + a.radius - b.lower[1],
b.upper[1] - (a.position[1] - a.radius),
a.position[0] + a.radius - b.lower[0],
b.upper[0] - (a.position[0] - a.radius),
]
)

best_shift = jnp.argmin(shifts)
return ContactInfo(-shifts[best_shift] * dirs[best_shift], ccp)

return jax.lax.cond(
a.contains(ccp),
lambda: jax.lax.cond(perfect_vertex, circle_dir_move, aligned_move),
lambda: ContactInfo.nan(),
)
Loading

0 comments on commit e2017c9

Please sign in to comment.