Skip to content

Commit

Permalink
Lunar landing: complete, with constraints
Browse files Browse the repository at this point in the history
  • Loading branch information
knyazer committed Feb 26, 2024
1 parent 13c80ef commit a6096b0
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 46 deletions.
8 changes: 8 additions & 0 deletions cotix/_bodies.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from equinox import AbstractVar
from jaxtyping import Array, Float

from ._geometry_utils import perpendicular_vector
from ._universal_shape import UniversalShape


Expand Down Expand Up @@ -46,6 +47,13 @@ def update_transform(self):
self.shape.update_transform(angle=self.angle, position=self.position),
)

def velocity_at(self, point):
return (
self.velocity
+ perpendicular_vector(point - self.get_center_of_mass())
* self.angular_velocity
)

def collides_with(self, other):
"""Returns boolean, whether there is a collision with another body."""
return self.shape.collides_with(other.shape)
Expand Down
33 changes: 14 additions & 19 deletions cotix/_collision_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,14 @@ def resolve_collision(
)


def apply_impulse(body, impulse, point):
arm = point - body.get_center_of_mass()
torque = jnp.cross(arm, impulse)
new_vel = body.velocity + impulse / body.mass
new_ang_vel = body.angular_velocity + torque / body.inertia
return body.set_velocity(new_vel).set_angular_velocity(new_ang_vel)


def resolve_collision_notnan(
body1: AbstractBody, body2: AbstractBody, contact_info: ContactInfo
) -> Tuple[AbstractBody, AbstractBody, CollisionResolutionExtraInfo]:
Expand Down Expand Up @@ -94,7 +102,7 @@ def resolve_collision_notnan(
contact_point_relative_velocity, normal_direction
)

baumgarte_term = 0.1
baumgarte_term = 0.3
elasticity = jnp.minimum(body1.elasticity, body2.elasticity)
r1 = contact_point - body1.get_center_of_mass()
r2 = contact_point - body2.get_center_of_mass()
Expand Down Expand Up @@ -126,30 +134,17 @@ def resolve_collision_notnan(

impulse_vec = impulse_vec + impulse_d_vec

torque1 = jnp.cross(r1, impulse_vec)
torque2 = jnp.cross(r2, impulse_vec)

# apply impulse
new_velocity1 = body1.velocity - impulse_vec / body1.mass
new_velocity2 = body2.velocity + impulse_vec / body2.mass
new_angular_velocity1 = body1.angular_velocity - torque1 / body1.inertia
new_angular_velocity2 = body2.angular_velocity + torque2 / body2.inertia

# condition to apply new impulses:
# if the bodies are moving apart, do nothing
# if the bodies are moving apart, do nothing
cond = jnp.dot(contact_info.penetration_vector, contact_point_relative_velocity) < 0

new_body1 = body1.set_velocity(new_velocity1).set_angular_velocity(
new_angular_velocity1
)
new_body2 = body2.set_velocity(new_velocity2).set_angular_velocity(
new_angular_velocity2
)

new_body1, new_body2 = jax.lax.cond(
cond,
lambda: (body1, body2),
lambda: (new_body1, new_body2),
lambda: (
apply_impulse(body1, -impulse_vec, contact_point),
apply_impulse(body2, impulse_vec, contact_point),
),
)
col = CollisionResolutionExtraInfo.make_default()
col = eqx.tree_at(lambda x: x.contact_point, col, contact_point.astype(jnp.float32))
Expand Down
10 changes: 10 additions & 0 deletions cotix/_geometry_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,16 @@ def angle_between(v1, v2):
return jnp.arccos(jnp.clip(jnp.dot(v1_u, v2_u), -1.0, 1.0))


def rotate(vector, angle_in_rad):
mat = jnp.array(
[
[jnp.cos(angle_in_rad), -jnp.sin(angle_in_rad)],
[jnp.sin(angle_in_rad), jnp.cos(angle_in_rad)],
]
)
return mat @ vector


class HomogenuousTransformer(eqx.Module, strict=True):
"""
Allows to apply arbitrary affine transformations to passed vectors/directions
Expand Down
155 changes: 130 additions & 25 deletions cotix/_lunar_lander.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,34 @@
import equinox as eqx
from jax import numpy as jnp
from jax import numpy as jnp, random as jr

from ._bodies import AnyBody
from ._convex_shapes import AABB, Polygon4, Polygon6
from ._collision_resolution import apply_impulse
from ._convex_shapes import Polygon4, Polygon6
from ._geometry_utils import rotate
from ._universal_shape import UniversalShape


LANDER_POLY = [
(-14, +17),
(-17, 0),
(-17, -10),
(+17, -10),
(+17, 0),
(+14, +17),
]

LEG_AWAY = 24
LEG_DOWN = 8
LEG_W, LEG_H = 2, 8
LEG_ANGLE = -0.3


class LunarLander(eqx.Module, strict=True):
bodies: list[AnyBody]

def __init__(self):
LANDER_POLY = [
(-14, +17),
(-17, 0),
(-17, -10),
(+17, -10),
(+17, 0),
(+14, +17),
]

def __init__(self, key=jr.PRNGKey(0)):
lander_shape = Polygon6(jnp.array(LANDER_POLY) * 0.05)

LEG_AWAY = 24
LEG_DOWN = 10
LEG_W, LEG_H = 2, 8
LEG_ANGLE = -0.3

left_leg_shape = Polygon4(
jnp.array(
[
Expand Down Expand Up @@ -68,13 +71,13 @@ def rotate(vertices, angle):
left_leg_shape.vertices * jnp.array([-1.0, 1.0]),
)

lander_center = jnp.array([0.2, 5.0])
lander_center = jnp.array([0.0, 5.0])
lander_left_leg = jnp.array([LEG_AWAY, -LEG_DOWN]) * 0.05 + lander_center
lander_right_leg = jnp.array([-LEG_AWAY, -LEG_DOWN]) * 0.05 + lander_center

lander = AnyBody(
position=lander_center,
velocity=jnp.array([0.1, -0.8]),
velocity=jnp.array([0.0, 0.0]),
angular_velocity=jnp.array(0.0),
mass=jnp.array(30.0),
inertia=jnp.array(30.0),
Expand All @@ -85,7 +88,7 @@ def rotate(vertices, angle):

right_leg = AnyBody(
position=lander_right_leg,
velocity=jnp.array([0.0, -0.8]),
velocity=jnp.array([0.0, 0.0]),
angular_velocity=jnp.array(0.0),
inertia=jnp.array(1.0),
friction_coefficient=jnp.array(0.1),
Expand All @@ -94,27 +97,129 @@ def rotate(vertices, angle):

left_leg = AnyBody(
position=lander_left_leg,
velocity=jnp.array([0.0, -0.8]),
velocity=jnp.array([0.0, 0.0]),
angular_velocity=jnp.array(0.0),
friction_coefficient=jnp.array(0.1),
inertia=jnp.array(1.0),
shape=UniversalShape(left_leg_shape),
)

# the ground is generated as a bunch of polygon4s that are
# connected to each other, with random-ish heights inbetween 0 and 1
k1, k2, k3, k4, k5 = jr.split(key, 5)
heights = jr.uniform(k1, (8,), minval=-5.0, maxval=5.0)
heights = heights.at[0].set(heights[0] * 10)
heights = heights.at[3].set(-2.0)
heights = heights.at[-4].set(-2.0)
heights = heights.at[-1].set(heights[-1] * 10)

positions = [
-100,
jr.uniform(k2, (), minval=-12.0, maxval=-9.0),
jr.uniform(k3, (), minval=-8.0, maxval=-4.0),
-2,
2,
jr.uniform(k4, (), minval=4.0, maxval=8.0),
jr.uniform(k5, (), minval=9.0, maxval=12.0),
100,
]
polygons = []
for i in range(len(heights) - 1):
p1 = jnp.array([positions[i], heights[i]])
p2 = jnp.array([positions[i], -10])
p3 = jnp.array([positions[i + 1], heights[i + 1]])
p4 = jnp.array([positions[i + 1], -10])
polygons.append(Polygon4(jnp.array([p1, p2, p3, p4])))

ground = AnyBody(
position=jnp.array([0.0, -5.0]),
position=jnp.array([0.0, 0.0]),
mass=jnp.array(jnp.inf),
inertia=jnp.array(jnp.inf),
elasticity=jnp.array(0.1),
friction_coefficient=jnp.array(0.1),
shape=UniversalShape(
AABB(jnp.array([-100.0, -2.0]), jnp.array([100.0, 2.0]))
),
shape=UniversalShape(*polygons),
)

self.bodies = [lander, right_leg, left_leg, ground]

def step(self):
lander_left_joint_1 = jnp.array([LEG_AWAY, -LEG_DOWN]) * 0.05
lander_left_joint_1 = (
rotate(lander_left_joint_1, self.bodies[0].angle) + self.bodies[0].position
)

lander_left_joint_2 = jnp.array([LEG_AWAY, -LEG_DOWN + 8]) * 0.05
lander_left_joint_2 = (
rotate(lander_left_joint_2, self.bodies[0].angle) + self.bodies[0].position
)

left_leg_joint_1 = self.bodies[2].position
left_leg_joint_2 = self.bodies[2].position + rotate(
jnp.array([0.0, 0.4]), self.bodies[2].angle
)

lander_right_joint_1 = jnp.array([-LEG_AWAY, -LEG_DOWN]) * 0.05
lander_right_joint_1 = (
rotate(lander_right_joint_1, self.bodies[0].angle) + self.bodies[0].position
)

lander_right_joint_2 = jnp.array([-LEG_AWAY, -LEG_DOWN + 8]) * 0.05
lander_right_joint_2 = (
rotate(lander_right_joint_2, self.bodies[0].angle) + self.bodies[0].position
)

right_leg_joint_1 = self.bodies[1].position
right_leg_joint_2 = self.bodies[1].position + rotate(
jnp.array([0.0, 0.4]), self.bodies[1].angle
)

def fixed_positional_constraint(
body_and_contact1, body_and_contact2, impulse_fn
):
delta_pos = body_and_contact1[1] - body_and_contact2[1]
delta_vel = body_and_contact1[0].velocity_at(
body_and_contact1[1]
) - body_and_contact2[0].velocity_at(body_and_contact2[1])
impulse = impulse_fn(delta_pos, delta_vel)
b1 = apply_impulse(body_and_contact1[0], -impulse, body_and_contact1[1])
b2 = apply_impulse(body_and_contact2[0], impulse, body_and_contact2[1])
return b1, b2

def impulse_fn(dp, dv):
return dp * 1.0 + dv * (jnp.linalg.norm(dv) + 0.1) * 0.05

new_bodies = self.bodies
lander, right_leg, left_leg = new_bodies[0], new_bodies[1], new_bodies[2]
lander, left_leg = fixed_positional_constraint(
(lander, lander_left_joint_1), (left_leg, left_leg_joint_1), impulse_fn
)
lander, left_leg = fixed_positional_constraint(
(lander, lander_left_joint_2), (left_leg, left_leg_joint_2), impulse_fn
)
lander, right_leg = fixed_positional_constraint(
(lander, lander_right_joint_1), (right_leg, right_leg_joint_1), impulse_fn
)
lander, right_leg = fixed_positional_constraint(
(lander, lander_right_joint_2), (right_leg, right_leg_joint_2), impulse_fn
)

# lets also dump angular velocities of the legs by a lot
right_leg = eqx.tree_at(
lambda x: x.angular_velocity, right_leg, right_leg.angular_velocity * 0.95
)
left_leg = eqx.tree_at(
lambda x: x.angular_velocity, left_leg, left_leg.angular_velocity * 0.95
)

new_bodies = eqx.tree_at(lambda x: x[0], new_bodies, lander)
new_bodies = eqx.tree_at(lambda x: x[1], new_bodies, right_leg)
new_bodies = eqx.tree_at(lambda x: x[2], new_bodies, left_leg)

return eqx.tree_at(lambda cls: cls.bodies, self, new_bodies)

def draw(self, painter):
for body in self.bodies:
body.draw(painter)
painter.draw_line((-2, -1.8), (-2, -1.0), color=(255, 0, 0))
painter.draw_line((2, -1.8), (2, -1.0), color=(255, 0, 0))
painter.next()
5 changes: 3 additions & 2 deletions test/test_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ def test_lunar_lander():

physics = ExplicitEulerPhysics()
collider = RandomizedCollider()
SimpleConstraintSolver(loops=1)
painter = Painter()

@eqx.filter_jit
Expand All @@ -28,7 +27,7 @@ def f(env, key):
new_bodies = eqx.tree_at(
lambda x: x[0].velocity,
new_bodies,
new_bodies[0].velocity + jnp.array([0.0, -0.001]),
new_bodies[0].velocity + jnp.array([0.0, -0.002]),
)

def draw_log(log):
Expand All @@ -39,6 +38,8 @@ def draw_log(log):
# new_bodies = constraintSolver.solve(new_bodies, env.constraints)
key, next_key = jr.split(key)
env = eqx.tree_at(lambda x: x.bodies, env, new_bodies)
env = env.step()

env.draw(painter)
return env, key

Expand Down

0 comments on commit a6096b0

Please sign in to comment.