Skip to content

Commit

Permalink
Add PDO data collection campaign management service
Browse files Browse the repository at this point in the history
  • Loading branch information
Ram81 committed Mar 11, 2024
1 parent 513ec93 commit 9149751
Show file tree
Hide file tree
Showing 8 changed files with 147 additions and 2 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ repos:
- types-Pillow
- types-tqdm
- types-PyYAML
- types-requests

- repo: https://github.com/kynan/nbstripout
rev: 0.5.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ habitat_baselines:
num_pool_agents_per_type: [1, 1]
policy:


habitat_hitl:
window:
title: "Rearrange"
Expand All @@ -35,3 +34,9 @@ habitat_hitl:
hide_humanoid_in_gui: True
camera:
first_person_mode: True
campaign_server:
url: "http://localhost:22362"
endpoints:
initialize_task: "api/v0/initialize_task"
update_task: "api/v0/update_task"
end_task: "api/v0/end_task"
22 changes: 21 additions & 1 deletion examples/hitl/rearrange_v2/rearrange_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
from habitat_hitl._internal.networking.average_rate_tracker import (
AverageRateTracker,
)
from habitat_hitl.app_states.app_service import AppService
from habitat_hitl.app_states.app_state_abc import AppState
from habitat_hitl.app_states.campaign_service import TaskStatus
from habitat_hitl.core.client_helper import ClientHelper
from habitat_hitl.core.gui_input import GuiInput
from habitat_hitl.core.hitl_main import hitl_main
Expand All @@ -40,7 +42,7 @@ class AppStateRearrangeV2(AppState):
Todo
"""

def __init__(self, app_service):
def __init__(self, app_service: AppService):
self._app_service = app_service
self._gui_agent_controllers = self._app_service.gui_agent_controllers
self._num_users = len(self._gui_agent_controllers)
Expand Down Expand Up @@ -77,6 +79,9 @@ def __init__(self, app_service):
self._sps_tracker = AverageRateTracker(2.0)

self._task_instruction = ""
self._num_episodes_completed = 0

self._app_service.campaign_service.initialize_session()

# needed to avoid spurious mypy attr-defined errors
@staticmethod
Expand Down Expand Up @@ -350,6 +355,7 @@ def _check_change_episode(self):
):
return

self._num_episodes_completed += 1
if self._app_service.episode_helper.next_episode_exists():
self._app_service.end_episode(do_reset=True)

Expand Down Expand Up @@ -433,6 +439,20 @@ def sim_update(self, dt, post_sim_update_dict):

self._update_help_text()

if (
self._num_episodes_completed
> self._app_service.campaign_service.max_episodes_per_session
):
self.end_task()

def end_task(self):
self._app_service.campaign_service.end_task(
{
"task_status": TaskStatus.COMPLETED.value,
**self._app_service.campaign_service.session_meta,
}
)


@hydra.main(
version_base=None, config_path="config", config_name="rearrange_v2"
Expand Down
8 changes: 8 additions & 0 deletions habitat-hitl/habitat_hitl/_internal/hitl_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
)
from habitat_hitl.app_states.app_service import AppService
from habitat_hitl.app_states.app_state_abc import AppState
from habitat_hitl.app_states.campaign_service import CampaignService
from habitat_hitl.core.client_message_manager import ClientMessageManager
from habitat_hitl.core.hydra_utils import omegaconf_to_object
from habitat_hitl.core.remote_gui_input import RemoteGuiInput
Expand Down Expand Up @@ -191,6 +192,12 @@ def local_end_episode(do_reset=False):
for controller in gui_agent_controllers:
controller.line_render = line_render

campaign_service = CampaignService(
hitl_config=self._hitl_config,
get_metrics=lambda: self._get_recent_metrics(),
episode_helper=self._episode_helper,
)

self._app_service = AppService(
config=config,
hitl_config=self._hitl_config,
Expand All @@ -209,6 +216,7 @@ def local_end_episode(do_reset=False):
episode_helper=self._episode_helper,
client_message_manager=self._client_message_manager,
gui_agent_controllers=gui_agent_controllers,
campaign_service=campaign_service,
)

self._app_state: AppState = None
Expand Down
7 changes: 7 additions & 0 deletions habitat-hitl/habitat_hitl/app_states/app_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from habitat import Env
from habitat.tasks.rearrange.rearrange_sim import RearrangeSim
from habitat_hitl.app_states.campaign_service import CampaignService
from habitat_hitl.core.client_message_manager import ClientMessageManager
from habitat_hitl.core.gui_input import GuiInput
from habitat_hitl.core.remote_gui_input import RemoteGuiInput
Expand Down Expand Up @@ -42,6 +43,7 @@ def __init__(
episode_helper: EpisodeHelper,
client_message_manager: ClientMessageManager,
gui_agent_controllers: List[GuiController],
campaign_service: CampaignService,
):
self._config = config
self._hitl_config = hitl_config
Expand All @@ -60,6 +62,7 @@ def __init__(
self._episode_helper = episode_helper
self._client_message_manager = client_message_manager
self._gui_agent_controllers = gui_agent_controllers
self._campaign_service = campaign_service

@property
def config(self):
Expand Down Expand Up @@ -128,3 +131,7 @@ def client_message_manager(self) -> ClientMessageManager:
@property
def gui_agent_controllers(self) -> List[GuiController]:
return self._gui_agent_controllers

@property
def campaign_service(self) -> CampaignService:
return self._campaign_service
94 changes: 94 additions & 0 deletions habitat-hitl/habitat_hitl/app_states/campaign_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#!/usr/bin/env python3

# Copyright (c) Meta Platforms, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import json
import random
import string
from enum import Enum
from typing import Any, Callable, Dict

import requests

from habitat_hitl.environment.episode_helper import EpisodeHelper


class TaskStatus(Enum):
INIT = "initialized"
IN_PROGRESS = "in_progress"
COMPLETED = "completed"
FAILED = "failed"


class CampaignService:
def __init__(
self,
*,
hitl_config,
get_metrics: Callable,
episode_helper: EpisodeHelper,
):
self._hitl_config = hitl_config
self._get_metrics = get_metrics
self._episode_helper = episode_helper

self.server_url = self._hitl_config.campaign_server.url
self.session_meta: Dict[str, Any] = {}

def post(self, url, data):
response = requests.post(url, data=json.dumps(data))
return response

def get(self, url):
response = requests.get(url)
return response

@staticmethod
def random_id(max_len: int = 10):
return "".join(
random.choice(string.ascii_uppercase + string.digits)
for _ in range(max_len)
)

def initialize_session(self):
self.session_meta["session_id"] = CampaignService.random_id()
self.session_meta["worker_id"] = CampaignService.random_id()
self.session_meta["mode"] = "sandbox"
response = self.initialize_task(
{
"scene_id": "dummy",
"episode_id": 0,
"task_status": TaskStatus.INIT.value,
}
)
if response.status_code == 200:
self.session_meta.update(response.json()["data"])

def set_task_status(self, status: TaskStatus):
endpoint = self._hitl_config.campaign_server.endpoints.update_task
url = f"{self.server_url}/{endpoint}"

response = self.post(url, {"status": status, **self.session_meta})
return response

def initialize_task(self, data: Dict[str, Any]):
endpoint = self._hitl_config.campaign_server.endpoints.initialize_task
url = f"{self.server_url}/{endpoint}"

response = self.post(url, {"data": data, **self.session_meta})
return response

def end_task(self, data: Dict[str, Any]):
endpoint = self._hitl_config.campaign_server.endpoints.end_task
url = f"{self.server_url}/{endpoint}"

response = self.post(url, {"data": data, **self.session_meta})
return response

@property
def max_episodes_per_session(self):
return self.session_meta.get("max_episodes_per_session", 100)
9 changes: 9 additions & 0 deletions habitat-hitl/habitat_hitl/config/hitl_defaults.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,12 @@ habitat_hitl:

# Save the gfx-replay keyframes to file. Use save_filepath_base to specify the filepath base. Gfx-replay files can be used elsewhere in Habitat, e.g. https://github.com/facebookresearch/habitat-lab/pull/1041. Capturing ends (is saved) when the session is over (pressed ESC). The file will be saved as `my_output/my_session.gfx_replay.json.gz`.
save_gfx_replay_keyframes: False

max_episodes_per_session: 100

campaign_server:
url: "http://localhost:22362"
endpoints:
initialize_task: "/initialize_task"
update_task: "/update_task"
end_task: "/end_task"
1 change: 1 addition & 0 deletions habitat-hitl/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
websockets
aiohttp
hydra-core
requests

0 comments on commit 9149751

Please sign in to comment.