Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

16 angular momentum and friction #48

Merged
merged 14 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 76 additions & 5 deletions cotix/_bodies.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class AbstractBody(eqx.Module, strict=True):
angle: AbstractVar[Float[Array, ""]]
angular_velocity: AbstractVar[Float[Array, ""]]

elasticity: AbstractVar[Float[Array, ""]]

shape: AbstractVar[UniversalShape]

def update_transform(self):
Expand Down Expand Up @@ -58,15 +60,17 @@ def set_inertia(self, inertia: Float[Array, ""]):

def set_position(self, position: Float[Array, "2"]):
"""Sets position of the center of mass of the body."""
return eqx.tree_at(lambda x: x.position, self, position)
tmp = eqx.tree_at(lambda x: x.position, self, position)
return tmp.update_transform()

def set_velocity(self, velocity: Float[Array, "2"]):
"""Sets velocity of the center of mass of the body."""
return eqx.tree_at(lambda x: x.velocity, self, velocity)

def set_angle(self, angle: Float[Array, ""]):
"""Sets the angle of rotation around its center of mass."""
return eqx.tree_at(lambda x: x.angle, self, angle)
tmp = eqx.tree_at(lambda x: x.angle, self, angle)
return tmp.update_transform()

def set_angular_velocity(self, angular_velocity: Float[Array, ""]):
"""Sets the rate of change of bodys angle."""
Expand All @@ -76,6 +80,13 @@ def set_shape(self, shape: UniversalShape):
"""Replaces the shape with any other shape: not recommended to use."""
return eqx.tree_at(lambda x: x.shape, self, shape)

def get_center_of_mass(self):
return self.position

def get_mass_matrix(self):
# it is a scalar since we are in 2d, so there is 1 axis of rotation
return self.inertia

def __invariant__(self):
return (
# Checks for nans
Expand Down Expand Up @@ -109,19 +120,32 @@ class Ball(AbstractBody, strict=True):
angle: Float[Array, ""]
angular_velocity: Float[Array, ""]

elasticity: Float[Array, ""]

shape: UniversalShape

def __init__(self, mass, velocity, shape):
def __init__(self, mass, position, velocity, shape):
# check that the shape is a circle
if not (isinstance(shape.parts[0], Circle) and len(shape.parts) == 1):
raise ValueError("Ball universal shape must be a circle")

self.mass = mass
self.inertia = mass
self.inertia = (
2 * (mass * shape.parts[0].radius ** 2) / 5
) # inertia of a solid ball

self.position = jnp.zeros((2,))
self.position = position
self.velocity = velocity

self.angle = jnp.array(0.0)
self.angular_velocity = jnp.array(0.0)

self.elasticity = jnp.array(1.0)

self.shape = shape
self.shape = self.shape.update_transform(
angle=self.angle, position=self.position
)

@staticmethod
def make_default():
Expand All @@ -131,6 +155,7 @@ def make_default():
ball = Ball(
jnp.array(1.0),
jnp.zeros((2,)),
jnp.zeros((2,)),
UniversalShape(Circle(jnp.array(0.05), jnp.zeros((2,)))),
)
ball = (
Expand All @@ -143,3 +168,49 @@ def make_default():
.set_shape(UniversalShape(Circle(jnp.array(0.05), jnp.zeros((2,)))))
)
return ball


class AnyBody(AbstractBody, strict=True):
"""
A body with any shape. Useful for tests.
"""

mass: Float[Array, ""]
inertia: Float[Array, ""]

position: Float[Array, "2"]
velocity: Float[Array, "2"]

angle: Float[Array, ""]
angular_velocity: Float[Array, ""]

elasticity: Float[Array, ""]

shape: UniversalShape

def __init__(
self,
mass,
inertia,
position,
velocity,
angle,
angular_velocity,
elasticity,
shape,
):
self.mass = mass
self.inertia = inertia

self.position = position
self.velocity = velocity

self.angle = angle
self.angular_velocity = angular_velocity

self.elasticity = elasticity

self.shape = shape
self.shape = self.shape.update_transform(
angle=self.angle, position=self.position
)
103 changes: 83 additions & 20 deletions cotix/_collision_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,17 @@
def _split_bodies(
body1: AbstractBody, body2: AbstractBody, epa_vector: Float[Array, "2"]
) -> Tuple[AbstractBody, AbstractBody]:
# lets apply translation to both bodies, taking their mass into account
# it may be not good idea to split 100% of the penetration,
# but we can change that later
split_portion = 1.0
# let's apply translation to both bodies, taking their mass into account
body1 = body1.set_position(
body1.position + epa_vector * (body2.mass / (body1.mass + body2.mass))
body1.position
+ split_portion * epa_vector * (body2.mass / (body1.mass + body2.mass))
)
body2 = body2.set_position(
body2.position - epa_vector * (body1.mass / (body1.mass + body2.mass))
body2.position
- split_portion * epa_vector * (body1.mass / (body1.mass + body2.mass))
)
return body1, body2

Expand All @@ -39,33 +44,91 @@ def _resolve_collision_checked(
def _resolve_collision(
body1: AbstractBody, body2: AbstractBody, epa_vector: Float[Array, "2"]
) -> Tuple[AbstractBody, AbstractBody]:
# todo: we need to determine the elasticity of the collision.
# probably based on the body properties
# todo: angular momentum

elasticity = 1
elasticity = body1.elasticity * body2.elasticity

# change coordinate system from (x, y) to (q, r)
# where q is the line along epa_vector and r is perpendicular to it
# where [0] is the line along epa_vector and [1] is perpendicular to it
unit_collision_vector = epa_vector / jnp.linalg.norm(epa_vector)
perpendicular = perpendicular_vector(unit_collision_vector)
change_of_basis = jnp.array([unit_collision_vector, perpendicular])
change_of_basis_inv = jnp.linalg.inv(change_of_basis)

v1q = jnp.dot(body1.velocity, unit_collision_vector)
v1r = jnp.dot(body1.velocity, perpendicular) # stays constant
v2q = jnp.dot(body2.velocity, unit_collision_vector)
v2r = jnp.dot(body2.velocity, perpendicular) # stays constant

v1q_new, v2q_new = _1d_elastic_collision_velocities(
body1.mass, body2.mass, v1q, v2q
# everything below should be in the new coordinate system
change_of_basis @ unit_collision_vector
perpendicular_new_basis = change_of_basis @ perpendicular
v1_col_basis = change_of_basis @ body1.velocity
v2_col_basis = change_of_basis @ body2.velocity

v_rel = v1_col_basis - v2_col_basis

# the contact point is set to be exactly between
# furthest (penetrating) points of the bodies
# along the collision direction
# get_global_support doesnt know about the new basis,
# so we use a vector from the old basis
contact_point = (
body1.shape.get_global_support(-unit_collision_vector)
+ body2.shape.get_global_support(unit_collision_vector)
) / 2
contact_point = change_of_basis @ contact_point

relative_contact_point1 = (
contact_point - change_of_basis @ body1.get_center_of_mass()
)
relative_contact_point2 = (
contact_point - change_of_basis @ body2.get_center_of_mass()
)

v1qr_new = jnp.array([v1q_new * elasticity, v1r])
v2qr_new = jnp.array([v2q_new * elasticity, v2r])
lever_arm1 = jnp.dot(relative_contact_point1, perpendicular_new_basis)
lever_arm2 = jnp.dot(relative_contact_point2, perpendicular_new_basis)

# jax.debug.print(
# "\nrelative_contact_points: "
# "{relative_contact_point1}, {relative_contact_point2}. "
# "perpendicular: {perpendicular}, "
# "lever arms: {lever_arm1}, {lever_arm2}. \n"
# "center1: {center1}, center2: {center2}. "
# "contact_point: {contact_point}. \n"
# "global_supports {sup1}, {sup2}. "
# "collision_unit_vector {unit_collision_vector}. \n",
# relative_contact_point1=relative_contact_point1,
# relative_contact_point2=relative_contact_point2,
# perpendicular=perpendicular_new_basis,
# lever_arm1=lever_arm1,
# lever_arm2=lever_arm2,
# center1=body1.get_center_of_mass(),
# center2=body2.get_center_of_mass(),
# contact_point=change_of_basis_inv @ contact_point,
# sup1=body1.shape.get_global_support(-unit_collision_vector),
# sup2=body2.shape.get_global_support(unit_collision_vector),
# unit_collision_vector=unit_collision_vector,
# )

# impulse computation is in accordance with
# https://github.com/knyazer/RSS/blob/
# 1246e03c5950a5549a128fbce97c7bd402f9bed7/engine/source/env/World.cpp#L87

# inertia is kg * m^2, so this is kg^-1
impulseFactor1 = (1 / body1.mass) + (lever_arm1**2) / body1.inertia
impulseFactor2 = (1 / body2.mass) + (lever_arm2**2) / body2.inertia

# this is a vector because v_rel is a vector
# units are kg * m / s
col_impulse = -(1 + elasticity) * v_rel / (impulseFactor1 + impulseFactor2)

v1_new_col_basis = v1_col_basis + col_impulse / body1.mass
v2_new_col_basis = v2_col_basis - col_impulse / body2.mass

# col_impulse[0] is along collision normal
body1 = body1.set_angular_velocity(
body1.angular_velocity + (lever_arm1 * col_impulse[0]) / body1.inertia
)
body2 = body2.set_angular_velocity(
body2.angular_velocity - (lever_arm2 * col_impulse[0]) / body2.inertia
)

v1_new = jnp.matmul(change_of_basis_inv, v1qr_new)
v2_new = jnp.matmul(change_of_basis_inv, v2qr_new)
v1_new = change_of_basis_inv @ v1_new_col_basis
v2_new = change_of_basis_inv @ v2_new_col_basis

body1 = body1.set_velocity(v1_new)
body2 = body2.set_velocity(v2_new)
Expand Down
8 changes: 7 additions & 1 deletion cotix/_geometry_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,13 @@ def perpendicular_vector(v):
return jnp.array([-v[1], v[0]])


class HomogenuousTransformer(eqx.Module, strict=True):
def angle_between(v1, v2):
v1_u = v1 / jnp.linalg.norm(v1)
v2_u = v2 / jnp.linalg.norm(v2)
return jnp.arccos(jnp.clip(jnp.dot(v1_u, v2_u), -1.0, 1.0))


class HomogeneousTransformer(eqx.Module, strict=True):
"""
Allows to apply arbitrary affine transformations to passed vectors/directions
"""
Expand Down
8 changes: 4 additions & 4 deletions cotix/_universal_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from ._abstract_shapes import AbstractShape, SupportFn
from ._collisions import check_for_collision_convex, compute_penetration_vector_convex
from ._geometry_utils import HomogenuousTransformer
from ._geometry_utils import HomogeneousTransformer


class UniversalShape(eqx.Module, strict=True):
Expand All @@ -22,11 +22,11 @@ class UniversalShape(eqx.Module, strict=True):
"""

parts: list[AbstractShape]
_transformer: HomogenuousTransformer
_transformer: HomogeneousTransformer

def __init__(self, *shapes: AbstractShape):
self.parts = [*shapes]
self._transformer = HomogenuousTransformer()
self._transformer = HomogeneousTransformer()

def wrap_local_support(self, support_fn: SupportFn) -> SupportFn:
"""
Expand Down Expand Up @@ -63,7 +63,7 @@ def update_transform(self, angle: Float[Array, ""], position: Float[Array, "2"])
return eqx.tree_at(
lambda x: x._transformer,
self,
HomogenuousTransformer(angle=angle, position=position),
HomogeneousTransformer(angle=angle, position=position),
)

def collides_with(self, other):
Expand Down
Loading
Loading