diff --git a/src/main.rs b/src/main.rs index 6b02417..8930d22 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,7 +8,7 @@ use std::{fmt::Display, str::FromStr, sync::OnceLock}; -use axum::{http::header::HeaderValue, response::Response}; +use axum::{http::header::HeaderValue, response::Response, routing::get, Json}; use bytes::Bytes; use deadpool_redis::redis::AsyncCommands; use serde_json::to_value; @@ -62,6 +62,18 @@ async fn get_voices( })) } +async fn get_translation_languages() -> ResponseResult>> { + let state = STATE.get().unwrap(); + let Some(token) = &state.translation_key else { + return Ok(Json(Vec::new())); + }; + + match translation::get_languages(&state.reqwest, token).await { + Ok(languages) => Ok(Json(languages)), + Err(err) => Err(Error::Unknown(err)), + } +} + #[derive(serde::Deserialize)] struct GetTTS { text: FixedString, @@ -342,11 +354,12 @@ async fn main() -> Result<()> { } let app = axum::Router::new() - .route("/tts", axum::routing::get(get_tts)) - .route("/voices", axum::routing::get(get_voices)) + .route("/tts", get(get_tts)) + .route("/voices", get(get_voices)) + .route("/translation_languages", get(get_translation_languages)) .route( "/modes", - axum::routing::get(|| async { + get(|| async { axum::Json([ TTSMode::gTTS.to_string(), TTSMode::Polly.to_string(), diff --git a/src/translation.rs b/src/translation.rs index f981444..902ab3a 100644 --- a/src/translation.rs +++ b/src/translation.rs @@ -1,6 +1,7 @@ use std::marker::PhantomData; use anyhow::Result; +use serde::ser::SerializeStruct; use small_fixed_array::FixedString; fn deserialize_single_seq<'de, T, D>(deserializer: D) -> Result, D::Error> @@ -47,9 +48,13 @@ struct TranslateResponse { pub translations: Option, } +fn auth_header(token: &str) -> String { + format!("DeepL-Auth-Key {token}") +} + pub async fn run( reqwest: &reqwest::Client, - translation_token: &str, + token: &str, content: &str, target_lang: &str, ) -> Result> { @@ -59,11 +64,10 @@ pub async fn run( preserve_formatting: 1, }; - let auth_header = format!("DeepL-Auth-Key {translation_token}"); let response: TranslateResponse = reqwest .get("https://api.deepl.com/v2/translate") .query(&request) - .header("Authorization", auth_header) + .header("Authorization", auth_header(token)) .send() .await? .error_for_status()? @@ -78,3 +82,44 @@ pub async fn run( Ok(None) } + +#[derive(serde::Deserialize)] +struct Voice { + pub name: FixedString, + pub language: FixedString, +} + +struct VoiceRequest; +impl serde::Serialize for VoiceRequest { + fn serialize(&self, serializer: S) -> std::prelude::v1::Result + where + S: serde::Serializer, + { + let mut serializer = serializer.serialize_struct("DeeplVoiceRequest", 1)?; + serializer.serialize_field("type", "target")?; + serializer.end() + } +} + +pub async fn get_languages( + reqwest: &reqwest::Client, + token: &str, +) -> Result> { + let languages: Vec = reqwest + .get("https://api.deepl.com/v2/languages") + .query(&VoiceRequest) + .header("Authorization", auth_header(token)) + .send() + .await? + .error_for_status()? + .json() + .await?; + + let language_map = languages + .into_iter() + .map(|v| (v.language, v.name)) + .collect(); + + println!("Loaded DeepL translation languages"); + Ok(language_map) +}