Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Search & Rescue Multi-Agent Environment #259

Open
wants to merge 25 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
3597d9e
feat: Implement predator prey env (#1)
zombie-einstein Nov 4, 2024
c955320
Merge branch 'instadeepai:main' into main
zombie-einstein Nov 4, 2024
6b34657
Merge branch 'main' into main
sash-a Nov 4, 2024
988339b
fix: PR fixes (#2)
zombie-einstein Nov 5, 2024
a0fe7a5
Merge branch 'instadeepai:main' into main
zombie-einstein Nov 5, 2024
b4cce01
style: Run updated pre-commit
zombie-einstein Nov 6, 2024
cb6d88d
refactor: Consolidate predator prey type
zombie-einstein Nov 7, 2024
06de3a0
feat: Implement search and rescue (#3)
zombie-einstein Nov 11, 2024
34beab6
fix: PR fixes (#4)
zombie-einstein Nov 14, 2024
f5fa659
Merge branch 'instadeepai:main' into main
zombie-einstein Nov 15, 2024
072db18
refactor: PR fixes (#5)
zombie-einstein Nov 19, 2024
162a74d
feat: Allow variable environment dimensions (#6)
zombie-einstein Nov 19, 2024
4996869
Merge branch 'main' into main
zombie-einstein Nov 22, 2024
6322f61
fix: Locate targets in single pass (#8)
zombie-einstein Nov 23, 2024
4ba7688
Merge branch 'instadeepai:main' into main
zombie-einstein Nov 28, 2024
9a654b9
feat: training and customisable observations (#7)
zombie-einstein Dec 7, 2024
5021e20
feat: view all targets (#9)
zombie-einstein Dec 9, 2024
c5c7b85
Merge branch 'instadeepai:main' into main
zombie-einstein Dec 9, 2024
13ffb84
Merge branch 'instadeepai:main' into main
zombie-einstein Dec 9, 2024
9e8ac5c
feat: Scaled rewards and target velocities (#10)
zombie-einstein Dec 11, 2024
5c509c7
Pass shape information to timesteps (#11)
zombie-einstein Dec 11, 2024
8acf242
test: extend tests and docs (#12)
zombie-einstein Dec 11, 2024
1792aa6
fix: unpin jax requirement
zombie-einstein Dec 12, 2024
1e66e78
Include agent positions in observation (#13)
zombie-einstein Dec 12, 2024
407ff79
Upgrade Esquilax and remove unused random keys (#14)
zombie-einstein Dec 27, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions docs/environments/predator_prey.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Predator-Prey Flock Environment

[//]: # (TODO: Add animated plot)

Environment modelling two competing flocks/swarms of agents:

- Predator agents are rewarded for contacting prey agents, or for proximity to prey agents.
- Prey agents are conversely penalised for being contacted by, or for proximity to predators.

Each set of agents can consist of multiple agents, each independently
updated, and with their own independent observations. The agents occupy a square
space with periodic boundary conditions. Agents have a limited view range, i.e. they
only partially observe their local environment (and the locations of neighbouring agents within
range). Rewards are also assigned individually to each agent dependent on their local state.

## Observation

Each agent generates an independent observation, an array of values
representing the distance along a ray from the agent to the nearest neighbour, with
each cell representing a ray angle (with `num_vision` rays evenly distributed over the agents
field of vision). Prey and prey agent types are visualised independently to allow agents
to observe both local position and type.

- `predators`: jax array (float) of shape `(num_predators, 2 * num_vision)` in the unit interval.
- `prey`: jax array (float) of shape `(num_prey, 2 * num_vision)` in the unit interval.

## Action

Agents can update their velocity each step by rotating and accelerating/decelerating. Values
are clipped to the range `[-1, 1]` and then scaled by max rotation and acceleration
parameters. Agents are restricted to velocities within a fixed range of speeds.

- `predators`: jax array (float) of shape (num_predators, 2) each corresponding to `[rotation, acceleration]`.
- `prey`: jax array (float) of shape (num_prey, 2) each corresponding to `[rotation, acceleration]`.

## Reward

Rewards can be either sparse or proximity-based.

### Sparse

- `predators`: jax array (float) of shape `(num_predators,)`, predators are rewarded a fixed amount
for coming into contact with a prey agent. If they are in contact with multiple prey, only the
nearest agent is selected.
- `prey`: jax array (float) of shape `(num_predators,)`, prey are penalised a fix negative amount if
they come into contact with a predator agent.

### Proximity

- `predators`: jax array (float) of shape `(num_predators,)`, predators are rewarded with an amount
scaled linearly with the distance to the prey agents, summed over agents in range.
- `prey`: jax array (float) of shape `(num_predators,)`, prey are penalised by an amount scaled linearly
with distance from predator agents, summed over all predators in range.
1 change: 1 addition & 0 deletions jumanji/environments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
from jumanji.environments.routing.snake.env import Snake
from jumanji.environments.routing.sokoban.env import Sokoban
from jumanji.environments.routing.tsp.env import TSP
from jumanji.environments.swarms.predator_prey import PredatorPrey


def is_colab() -> bool:
Expand Down
13 changes: 13 additions & 0 deletions jumanji/environments/swarms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
13 changes: 13 additions & 0 deletions jumanji/environments/swarms/common/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
53 changes: 53 additions & 0 deletions jumanji/environments/swarms/common/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from dataclasses import dataclass
else:
from chex import dataclass

import chex


@dataclass
zombie-einstein marked this conversation as resolved.
Show resolved Hide resolved
class AgentParams:
"""
max_rotate: Max angle an agent can rotate during a step (a fraction of pi)
max_accelerate: Max change in speed during a step
min_speed: Minimum agent speed
max_speed: Maximum agent speed
view_angle: Agent view angle, as a fraction of pi either side of its heading
"""

max_rotate: float
max_accelerate: float
min_speed: float
max_speed: float
view_angle: float


@dataclass
class AgentState:
"""
State of multiple agents of a single type

pos: 2d position of the (centre of the) agents
heading: Heading of the agents (in radians)
speed: Speed of the agents
"""

pos: chex.Array
heading: chex.Array
speed: chex.Array
zombie-einstein marked this conversation as resolved.
Show resolved Hide resolved
183 changes: 183 additions & 0 deletions jumanji/environments/swarms/common/updates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Tuple

import chex
import esquilax
import jax
import jax.numpy as jnp

from . import types
from .types import AgentParams
zombie-einstein marked this conversation as resolved.
Show resolved Hide resolved


@esquilax.transforms.amap
def update_velocity(
_: chex.PRNGKey,
params: types.AgentParams,
x: Tuple[chex.Array, types.AgentState],
) -> Tuple[float, float]:
"""
Get the updated agent heading and speeds from actions

Args:
_: Dummy JAX random key.
params: Agent parameters.
x: Agent rotation and acceleration actions.

Returns:
float: New agent heading.
float: New agent speed.
"""
actions, boid = x
rotation = actions[0] * params.max_rotate * jnp.pi
acceleration = actions[1] * params.max_accelerate

new_heading = (boid.heading + rotation) % (2 * jnp.pi)
new_speeds = jnp.clip(
boid.speed + acceleration,
min=params.min_speed,
max=params.max_speed,
)

return new_heading, new_speeds


@esquilax.transforms.amap
zombie-einstein marked this conversation as resolved.
Show resolved Hide resolved
def move(
_: chex.PRNGKey, _params: None, x: Tuple[chex.Array, float, float]
) -> chex.Array:
"""
Get updated agent positions from current speed and heading

Args:
_: Dummy JAX random key.
_params: unused parameters.
x: Tuple containing current agent position, heading and speed.

Returns:
jax array (float32): Updated agent position
"""
pos, heading, speed = x
d_pos = jnp.array([speed * jnp.cos(heading), speed * jnp.sin(heading)])
return (pos + d_pos) % 1.0


def init_state(
n: int, params: types.AgentParams, key: chex.PRNGKey
) -> types.AgentState:
"""
Randomly initialise state of a group of agents

Args:
n: Number of agents to initialise.
params: Agent parameters.
key: JAX random key.

Returns:
AgentState: Random agent states (i.e. position, headings, and speeds)
"""
k1, k2, k3 = jax.random.split(key, 3)

positions = jax.random.uniform(k1, (n, 2))
speeds = jax.random.uniform(
k2, (n,), minval=params.min_speed, maxval=params.max_speed
)
headings = jax.random.uniform(k3, (n,), minval=0.0, maxval=2.0 * jax.numpy.pi)

return types.AgentState(
pos=positions,
speed=speeds,
heading=headings,
)
zombie-einstein marked this conversation as resolved.
Show resolved Hide resolved


def update_state(
key: chex.PRNGKey, params: AgentParams, state: types.AgentState, actions: chex.Array
) -> types.AgentState:
"""
Update the state of a group of agents from a sample of actions

Args:
key: Dummy JAX random key.
params: Agent parameters.
state: Current agent states.
actions: Agent actions, i.e. a 2D array of action for each agent.

Returns:
AgentState: Updated state of the agents after applying steering
actions and updating positions.
"""
actions = jax.numpy.clip(actions, min=-1.0, max=1.0)
zombie-einstein marked this conversation as resolved.
Show resolved Hide resolved
headings, speeds = update_velocity(key, params, (actions, state))
positions = move(key, None, (state.pos, headings, speeds))

return types.AgentState(
pos=positions,
speed=speeds,
heading=headings,
)


def view(
_: chex.PRNGKey,
zombie-einstein marked this conversation as resolved.
Show resolved Hide resolved
params: Tuple[float, float],
sash-a marked this conversation as resolved.
Show resolved Hide resolved
a: types.AgentState,
b: types.AgentState,
zombie-einstein marked this conversation as resolved.
Show resolved Hide resolved
*,
n_view: int,
i_range: float,
) -> chex.Array:
"""
Simple agent view model

Simple view model where the agents view angle is subdivided
into an array of values representing the distance from
the agent along a rays from the agent, with rays evenly distributed.
across the agents field of view. The limit of vision is set at 1.0,
which is also the default value if no object is within range.
Currently, this model assumes the viewed objects are circular.

Args:
_: Dummy JAX random key.
params: Tuple containing agent view angle and view-radius.
a: Viewing agent state.
b: State of agent being viewed.
n_view: Static number of view rays/subdivisions (i.e. how
many cells the resulting array contains).
i_range: Static agent view/interaction range.

Returns:
jax array (float32): 1D array representing the distance
along a ray from the agent to another agent.
"""
view_angle, radius = params
rays = jnp.linspace(
-view_angle * jnp.pi,
view_angle * jnp.pi,
n_view,
endpoint=True,
)
dx = esquilax.utils.shortest_vector(a.pos, b.pos)
d = jnp.sqrt(jnp.sum(dx * dx)) / i_range
phi = jnp.arctan2(dx[1], dx[0]) % (2 * jnp.pi)
dh = esquilax.utils.shortest_vector(phi, a.heading, 2 * jnp.pi)

angular_width = jnp.arctan2(radius, d)
left = dh - angular_width
right = dh + angular_width

obs = jnp.where(jnp.logical_and(left < rays, rays < right), d, 1.0)
return obs
Loading