Skip to content

Commit

Permalink
Merge pull request #234 from Datura-ai/main
Browse files Browse the repository at this point in the history
deploy
  • Loading branch information
pyon12 authored Feb 5, 2025
2 parents ccc5e51 + 7b1a432 commit d195e53
Show file tree
Hide file tree
Showing 9 changed files with 386 additions and 119 deletions.
80 changes: 78 additions & 2 deletions neurons/validators/src/clients/compute_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
DuplicateExecutorsRequest,
ExecutorSpecRequest,
LogStreamRequest,
ResetVerifiedJobRequest,
RentedMachineRequest,
)
from pydantic import BaseModel
Expand All @@ -36,8 +37,9 @@
from services.redis_service import (
DUPLICATED_MACHINE_SET,
MACHINE_SPEC_CHANNEL_NAME,
RENTED_MACHINE_SET,
RENTED_MACHINE_PREFIX,
STREAMING_LOG_CHANNEL,
RESET_VERIFIED_JOB_CHANNEL,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -118,10 +120,12 @@ async def run_forever(self) -> NoReturn:
# subscribe to channel to get machine specs
pubsub = await self.miner_service.redis_service.subscribe(MACHINE_SPEC_CHANNEL_NAME)
log_channel = await self.miner_service.redis_service.subscribe(STREAMING_LOG_CHANNEL)
reset_verified_job_channel = await self.miner_service.redis_service.subscribe(RESET_VERIFIED_JOB_CHANNEL)

# send machine specs to facilitator
self.specs_task = asyncio.create_task(self.wait_for_specs(pubsub))
asyncio.create_task(self.wait_for_log_streams(log_channel))
asyncio.create_task(self.wait_for_reset_verified_job(reset_verified_job_channel))
except Exception as exc:
logger.error(
_m("redis connection error", extra={**self.logging_extra, "error": str(exc)})
Expand Down Expand Up @@ -366,6 +370,78 @@ async def wait_for_log_streams(self, channel: aioredis.client.PubSub):
except TimeoutError:
pass

async def wait_for_reset_verified_job(self, channel: aioredis.client.PubSub):
    """Listen on the reset-verified-job pubsub channel and forward requests.

    Runs forever: each message received on *channel* is parsed into a
    ``ResetVerifiedJobRequest`` and sent to the facilitator over the
    websocket. Requests that fail to send (websocket down or send error)
    are kept in a local retry queue and re-sent on a later iteration.

    Args:
        channel: Redis pubsub subscription for RESET_VERIFIED_JOB_CHANNEL.
    """
    # Requests still awaiting delivery to the facilitator (retry queue).
    pending_requests: list[ResetVerifiedJobRequest] = []
    while True:
        validator_hotkey = self.my_hotkey()
        logger.info(
            _m(
                f"Waiting for clear verified jobs: {validator_hotkey}",
                extra=self.logging_extra,
            )
        )
        try:
            # Long poll: 100 minutes before get_message gives up.
            msg = await channel.get_message(ignore_subscribe_messages=True, timeout=100 * 60)
            if msg is None:
                logger.warning(
                    _m(
                        "No clear job request yet",
                        extra=self.logging_extra,
                    )
                )
                continue

            msg = json.loads(msg["data"])
            reset_request = None

            try:
                reset_request = ResetVerifiedJobRequest(
                    miner_hotkey=msg["miner_hotkey"],
                    validator_hotkey=validator_hotkey,
                    executor_uuid=msg["executor_uuid"],
                )

                logger.info(
                    _m(
                        f'Successfully created ResetVerifiedJobRequest instance with {msg}',
                        extra=self.logging_extra,
                    )
                )
            except Exception as exc:
                # Malformed message: log and skip it; keep the loop alive.
                logger.error(
                    _m(
                        "Failed to get ResetVerifiedJobRequest instance",
                        extra={
                            **self.logging_extra,
                            "error": str(exc),
                            "msg": str(msg),
                        },
                    ),
                    exc_info=True,
                )
                continue

            pending_requests.append(reset_request)
            if self.ws is not None:
                # Drain the retry queue in FIFO order while sends succeed.
                while len(pending_requests) > 0:
                    request_to_send = pending_requests.pop(0)
                    try:
                        await self.send_model(request_to_send)
                    except Exception as exc:
                        # Put the request back for a later retry and stop
                        # draining until the next message wakes us up.
                        pending_requests.insert(0, request_to_send)
                        logger.error(
                            _m(
                                # Fixed: previously the raw parsed message
                                # dict was passed as the log message here.
                                "Failed to send ResetVerifiedJobRequest",
                                extra={
                                    **self.logging_extra,
                                    "error": str(exc),
                                },
                            )
                        )
                        break
        except TimeoutError:
            # get_message timed out with no traffic; loop and wait again.
            pass

def create_metagraph_refresh_task(self, period=None):
return create_metagraph_refresh_task(period=period)

Expand Down Expand Up @@ -445,7 +521,7 @@ async def handle_message(self, raw_msg: str | bytes):
)

redis_service = self.miner_service.redis_service
await redis_service.delete(RENTED_MACHINE_SET)
await redis_service.delete(RENTED_MACHINE_PREFIX)

for machine in response.machines:
await redis_service.add_rented_machine(machine)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class RentedMachine(BaseModel):
executor_id: str
executor_ip_address: str
executor_ip_port: str
container_name: str


class RentedMachineResponse(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class RequestType(enum.Enum):
ExecutorSpecRequest = "ExecutorSpecRequest"
RentedMachineRequest = "RentedMachineRequest"
LogStreamRequest = "LogStreamRequest"
ResetVerifiedJobRequest = "ResetVerifiedJobRequest"
DuplicateExecutorsRequest = "DuplicateExecutorsRequest"


Expand Down Expand Up @@ -73,5 +74,12 @@ class LogStreamRequest(BaseValidatorRequest):
logs: list[dict]


class ResetVerifiedJobRequest(BaseValidatorRequest):
    """Validator-to-facilitator message requesting a reset of an executor's verified-job state."""

    # Discriminator used by the facilitator to route this message type.
    message_type: RequestType = RequestType.ResetVerifiedJobRequest
    validator_hotkey: str  # hotkey of the validator issuing the reset
    miner_hotkey: str  # hotkey of the miner that owns the executor
    executor_uuid: str  # UUID of the executor whose verified jobs are reset


class DuplicateExecutorsRequest(BaseValidatorRequest):
message_type: RequestType = RequestType.DuplicateExecutorsRequest
46 changes: 19 additions & 27 deletions neurons/validators/src/services/docker_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,9 +236,9 @@ async def clean_exisiting_containers(
command = f'docker volume prune -af'
await ssh_client.run(command)

async def clear_verified_job_count(self, renting_machine: RentedMachine):
await self.redis_service.remove_pending_pod(renting_machine)
await self.redis_service.clear_verified_job_info(renting_machine.executor_id)
async def clear_verified_job_count(self, miner_hotkey: str, executor_id: str) -> None:
    """Remove the pending-pod marker and verified-job info for an executor.

    Invoked on container-creation failure paths so the executor's
    verification state in Redis is reset rather than left stale.

    Args:
        miner_hotkey: Hotkey of the miner that owns the executor.
        executor_id: Identifier of the executor to clear.
    """
    await self.redis_service.remove_pending_pod(miner_hotkey, executor_id)
    await self.redis_service.clear_verified_job_info(miner_hotkey, executor_id)

async def create_container(
self,
Expand All @@ -258,13 +258,6 @@ async def create_container(
"debug": payload.debug,
}

renting_machine = RentedMachine(
miner_hotkey=payload.miner_hotkey,
executor_id=payload.executor_id,
executor_ip_address=executor_info.address,
executor_ip_port=str(executor_info.port),
)

logger.info(
_m(
"Create Docker Container",
Expand All @@ -290,7 +283,7 @@ async def create_container(
log_text = "No port mappings found"
logger.error(log_text)

await self.clear_verified_job_count(renting_machine)
await self.clear_verified_job_count(payload.miner_hotkey, payload.executor_id)

return FailedContainerRequest(
miner_hotkey=payload.miner_hotkey,
Expand All @@ -300,7 +293,7 @@ async def create_container(
)

# add executor in pending status dict
await self.redis_service.add_pending_pod(renting_machine)
await self.redis_service.add_pending_pod(payload.miner_hotkey, payload.executor_id)

private_key = self.ssh_service.decrypt_payload(keypair.ss58_address, private_key)
pkey = asyncssh.import_private_key(private_key)
Expand Down Expand Up @@ -356,7 +349,7 @@ async def create_container(
logger.error(log_text)

await self.finish_stream_logs()
await self.clear_verified_job_count(renting_machine)
await self.clear_verified_job_count(payload.miner_hotkey, payload.executor_id)

return FailedContainerRequest(
miner_hotkey=payload.miner_hotkey,
Expand Down Expand Up @@ -439,7 +432,7 @@ async def create_container(
logger.error(log_text)

await self.finish_stream_logs()
await self.clear_verified_job_count(renting_machine)
await self.clear_verified_job_count(payload.miner_hotkey, payload.executor_id)

return FailedContainerRequest(
miner_hotkey=payload.miner_hotkey,
Expand Down Expand Up @@ -497,7 +490,7 @@ async def create_container(
logger.error(log_text)

await self.finish_stream_logs()
await self.clear_verified_job_count(renting_machine)
await self.clear_verified_job_count(payload.miner_hotkey, payload.executor_id)

return FailedContainerRequest(
miner_hotkey=payload.miner_hotkey,
Expand All @@ -518,7 +511,7 @@ async def create_container(
logger.error(log_text)

await self.finish_stream_logs()
await self.clear_verified_job_count(renting_machine)
await self.clear_verified_job_count(payload.miner_hotkey, payload.executor_id)

return FailedContainerRequest(
miner_hotkey=payload.miner_hotkey,
Expand All @@ -536,8 +529,14 @@ async def create_container(

await self.finish_stream_logs()

await self.redis_service.add_rented_machine(renting_machine)
await self.redis_service.remove_pending_pod(renting_machine)
await self.redis_service.add_rented_machine(RentedMachine(
miner_hotkey=payload.miner_hotkey,
executor_id=payload.executor_id,
executor_ip_address=executor_info.address,
executor_ip_port=str(executor_info.port),
container_name=container_name,
))
await self.redis_service.remove_pending_pod(payload.miner_hotkey, payload.executor_id)

return ContainerCreatedResult(
container_name=container_name,
Expand All @@ -554,7 +553,7 @@ async def create_container(
logger.error(log_text, exc_info=True)

await self.finish_stream_logs()
await self.clear_verified_job_count(renting_machine)
await self.clear_verified_job_count(payload.miner_hotkey, payload.executor_id)

return FailedContainerRequest(
miner_hotkey=payload.miner_hotkey,
Expand Down Expand Up @@ -699,14 +698,7 @@ async def delete_container(
),
)

await self.redis_service.remove_rented_machine(
RentedMachine(
miner_hotkey=payload.miner_hotkey,
executor_id=payload.executor_id,
executor_ip_address=executor_info.address,
executor_ip_port=str(executor_info.port),
)
)
await self.redis_service.remove_rented_machine(payload.miner_hotkey, payload.executor_id)

async def get_docker_hub_digests(self, repositories) -> dict[str, str]:
"""Retrieve all tags and their corresponding digests from Docker Hub."""
Expand Down
9 changes: 1 addition & 8 deletions neurons/validators/src/services/miner_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,14 +387,7 @@ async def handle_container(self, payload: ContainerBaseRequest):
)
)

await self.redis_service.remove_rented_machine(
RentedMachine(
miner_hotkey=payload.miner_hotkey,
executor_id=payload.executor_id,
executor_ip_address=executor.address if executor else "",
executor_ip_port=str(executor.port if executor else ""),
)
)
await self.redis_service.remove_rented_machine(payload.miner_hotkey, payload.executor_id)

return FailedContainerRequest(
miner_hotkey=payload.miner_hotkey,
Expand Down
Loading

0 comments on commit d195e53

Please sign in to comment.