Skip to content

Commit

Permalink
chore: deserialize status updates
Browse files Browse the repository at this point in the history
  • Loading branch information
mcharytoniuk committed Nov 14, 2024
1 parent a8a5579 commit 4e07b68
Show file tree
Hide file tree
Showing 10 changed files with 130 additions and 53 deletions.
31 changes: 30 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ actix = "0.13.5"
actix-web = "4.9.0"
clap = { version = "4.5.20", features = ["derive"] }
env_logger = "0.11.5"
futures = "0.3.31"
futures-util = { version = "0.3.31", features = ["tokio-io"] }
log = "0.4.22"
reqwest = { version = "0.12.9", features = ["json", "stream"] }
Expand Down
21 changes: 14 additions & 7 deletions src/agent/agent.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,33 @@
use actix::Addr;
use uuid::Uuid;

use crate::agent::state_reporter::StateReporter;
use crate::balancer::status_update::StatusUpdate;
use crate::errors::result::Result;
use crate::llamacpp::llamacpp_client::LlamacppClient;

pub struct Agent {
id: uuid::Uuid,
external_llamacpp_addr: url::Url,
id: Uuid,
name: Option<String>,
llamacpp_client: LlamacppClient,
}

impl Agent {
pub fn new(llamacpp_client: LlamacppClient, name: Option<String>) -> Self {
pub fn new(
external_llamacpp_addr: url::Url,
llamacpp_client: LlamacppClient,
name: Option<String>,
) -> Self {
Self {
id: uuid::Uuid::new_v4(),
external_llamacpp_addr,
id: Uuid::new_v4(),
name,
llamacpp_client,
}
}

pub async fn observe_and_report(
&self,
state_reporter: &actix::Addr<StateReporter>,
) -> Result<()> {
pub async fn observe_and_report(&self, state_reporter: &Addr<StateReporter>) -> Result<()> {
let status = self.observe().await?;

Ok(state_reporter.send(status).await?)
Expand All @@ -31,6 +37,7 @@ impl Agent {
Ok(StatusUpdate::new(
self.id,
self.name.clone(),
self.external_llamacpp_addr.clone(),
self.llamacpp_client.get_available_slots().await?,
))
}
Expand Down
45 changes: 27 additions & 18 deletions src/agent/state_reporter.rs
Original file line number Diff line number Diff line change
@@ -1,41 +1,50 @@
use actix::{fut::future::WrapFuture, AsyncContext};
use actix::{fut::future::WrapFuture, Actor, AsyncContext, Context, Handler};
use log::error;
use std::sync::Arc;
use tokio::sync::broadcast;
use serde_json::to_vec;
use std::{
sync::{Arc, Mutex},
time::Duration,
};
use tokio::sync::broadcast::{channel, Receiver, Sender};
use tokio_stream::wrappers::BroadcastStream;
use url::Url;

use crate::balancer::status_update::StatusUpdate;
use crate::errors::result::Result;

#[allow(dead_code)]
pub struct StateReporter {
interval_running: std::sync::Arc<std::sync::Mutex<bool>>,
interval_running: Arc<Mutex<bool>>,
stats_endpoint_url: String,
status_update_rx: broadcast::Receiver<actix_web::web::Bytes>,
status_update_tx: Arc<broadcast::Sender<actix_web::web::Bytes>>,
status_update_tx: Arc<Sender<actix_web::web::Bytes>>,

// sending on the channel fails once every receiver has been dropped,
// therefore, we need to keep a reference to the initial receiver
status_update_rx: Receiver<actix_web::web::Bytes>,
}

impl StateReporter {
pub fn new(management_addr: url::Url) -> Result<Self> {
let (tx, rx) = broadcast::channel(1);
pub fn new(management_addr: Url) -> Result<Self> {
let (tx, rx) = channel(1);

Ok(Self {
interval_running: std::sync::Arc::new(std::sync::Mutex::new(false)),
stats_endpoint_url: management_addr.join("/stream")?.to_string(),
interval_running: Arc::new(Mutex::new(false)),
stats_endpoint_url: management_addr.join("/status_update")?.to_string(),
status_update_rx: rx,
status_update_tx: Arc::new(tx),
})
}
}

impl actix::Actor for StateReporter {
type Context = actix::Context<Self>;
impl Actor for StateReporter {
type Context = Context<Self>;

fn started(&mut self, ctx: &mut actix::Context<Self>) {
fn started(&mut self, ctx: &mut Context<Self>) {
let stats_endpoint_url = self.stats_endpoint_url.clone();
let status_update_tx = self.status_update_tx.clone();
let interval_running = self.interval_running.clone();

ctx.run_interval(std::time::Duration::from_secs(1), move |actor, ctx| {
ctx.run_interval(Duration::from_secs(1), move |actor, ctx| {
let interval_running = interval_running.clone();
let stats_endpoint_url = stats_endpoint_url.clone();
let status_update_tx = status_update_tx.clone();
Expand All @@ -59,7 +68,7 @@ impl actix::Actor for StateReporter {
};

let rx = status_update_tx.subscribe();
let stream = tokio_stream::wrappers::BroadcastStream::new(rx);
let stream = BroadcastStream::new(rx);
let reqwest_body = reqwest::Body::wrap_stream(stream);

let result = reqwest::Client::new()
Expand Down Expand Up @@ -95,11 +104,11 @@ impl actix::Actor for StateReporter {
}
}

impl actix::Handler<StatusUpdate> for StateReporter {
impl Handler<StatusUpdate> for StateReporter {
type Result = ();

fn handle(&mut self, msg: StatusUpdate, _ctx: &mut actix::Context<Self>) {
let bytes = match serde_json::to_vec(&msg) {
fn handle(&mut self, msg: StatusUpdate, _ctx: &mut Context<Self>) {
let bytes = match to_vec(&msg) {
Ok(bytes) => bytes,
Err(err) => {
error!("Could not convert status update to bytes: {}", err);
Expand Down
19 changes: 12 additions & 7 deletions src/balancer/http_route/receive_status_update.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
use actix_web::{post, web, Error, HttpResponse};
use futures_util::StreamExt as _;

use crate::balancer::status_update::StatusUpdate;

pub fn register(cfg: &mut web::ServiceConfig) {
cfg.service(respond);
}

#[post("/stream")]
#[post("/status_update")]
async fn respond(mut payload: web::Payload) -> Result<HttpResponse, Error> {
println!("Stream started");

while let Some(chunk) = payload.next().await {
println!("Chunk: {:?}", chunk);
match serde_json::from_slice::<StatusUpdate>(&chunk?) {
Ok(status_update) => {
println!("Received status update: {:?}", status_update);
}
Err(e) => {
return Err(Error::from(e));
}
}
}

println!("Stream ended");

Ok(HttpResponse::Ok().finish())
Ok(HttpResponse::Accepted().finish())
}
12 changes: 10 additions & 2 deletions src/balancer/status_update.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
use serde::{Deserialize, Serialize};
use uuid::Uuid;

use crate::llamacpp::slot::Slot;

#[derive(Debug, Serialize, Deserialize)]
pub struct StatusUpdate {
agent_id: uuid::Uuid,
agent_id: Uuid,
agent_name: Option<String>,
external_llamacpp_addr: url::Url,
slots: Vec<Slot>,
}

impl StatusUpdate {
pub fn new(agent_id: uuid::Uuid, agent_name: Option<String>, slots: Vec<Slot>) -> Self {
pub fn new(
agent_id: Uuid,
agent_name: Option<String>,
external_llamacpp_addr: url::Url,
slots: Vec<Slot>,
) -> Self {
Self {
agent_id,
agent_name,
external_llamacpp_addr,
slots,
}
}
Expand Down
10 changes: 8 additions & 2 deletions src/cmd/agent.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use actix::Actor;
use log::error;
use time::{sleep, Duration};
use tokio::time;

use crate::agent::{agent::Agent, state_reporter::StateReporter};
use crate::errors::result::Result;
use crate::llamacpp::llamacpp_client::LlamacppClient;

pub async fn handle(
external_llamacpp_addr: &url::Url,
local_llamacpp_addr: &url::Url,
local_llamacpp_api_key: &Option<String>,
management_addr: &url::Url,
Expand All @@ -15,13 +17,17 @@ pub async fn handle(
let state_reporter_addr = StateReporter::new(management_addr.clone())?.start();
let llamacpp_client =
LlamacppClient::new(local_llamacpp_addr.clone(), local_llamacpp_api_key.clone())?;
let agent = Agent::new(llamacpp_client, name.clone());
let agent = Agent::new(
external_llamacpp_addr.clone(),
llamacpp_client,
name.clone(),
);

loop {
if let Err(err) = agent.observe_and_report(&state_reporter_addr).await {
error!("Unable to connect to llamacpp server: {}", err);
}

time::sleep(time::Duration::from_secs(1)).await;
sleep(Duration::from_secs(1)).await;
}
}
21 changes: 15 additions & 6 deletions src/cmd/balancer.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
use actix_web::{App, HttpServer};
use futures::future;
use std::net::SocketAddr;

use crate::balancer::http_route;
use crate::errors::result::Result;

pub async fn handle(_management_addr: &url::Url, _reverseproxy_addr: &url::Url) -> Result<()> {
Ok(
pub async fn handle(management_addr: &SocketAddr, reverseproxy_addr: &SocketAddr) -> Result<()> {
let management_server =
HttpServer::new(move || App::new().configure(http_route::receive_status_update::register))
.bind("127.0.0.1:8095")?
.run()
.await?,
)
.bind(management_addr)?
.run();

let reverseproxy_server =
HttpServer::new(move || App::new().configure(http_route::receive_status_update::register))
.bind(reverseproxy_addr)?
.run();

future::try_join(management_server, reverseproxy_server).await?;

Ok(())
}
3 changes: 2 additions & 1 deletion src/llamacpp/llamacpp_client.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use reqwest::header;
use std::time::Duration;
use url::Url;

use crate::errors::result::Result;
use crate::llamacpp::slot::Slot;
Expand All @@ -10,7 +11,7 @@ pub struct LlamacppClient {
}

impl LlamacppClient {
pub fn new(addr: url::Url, api_key: Option<String>) -> Result<Self> {
pub fn new(addr: Url, api_key: Option<String>) -> Result<Self> {
let mut builder = reqwest::Client::builder().timeout(Duration::from_secs(3));

builder = match api_key {
Expand Down
Loading

0 comments on commit 4e07b68

Please sign in to comment.