Skip to content

Commit

Permalink
Add translation languages endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
GnomedDev committed Mar 5, 2024
1 parent e128c22 commit 15a07b9
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 7 deletions.
21 changes: 17 additions & 4 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

use std::{fmt::Display, str::FromStr, sync::OnceLock};

use axum::{http::header::HeaderValue, response::Response};
use axum::{http::header::HeaderValue, response::Response, routing::get, Json};
use bytes::Bytes;
use deadpool_redis::redis::AsyncCommands;
use serde_json::to_value;
Expand Down Expand Up @@ -62,6 +62,18 @@ async fn get_voices(
}))
}

async fn get_translation_languages() -> ResponseResult<Json<Vec<(FixedString, FixedString)>>> {
let state = STATE.get().unwrap();
let Some(token) = &state.translation_key else {
return Ok(Json(Vec::new()));
};

match translation::get_languages(&state.reqwest, token).await {
Ok(languages) => Ok(Json(languages)),
Err(err) => Err(Error::Unknown(err)),
}
}

#[derive(serde::Deserialize)]
struct GetTTS {
text: FixedString,
Expand Down Expand Up @@ -342,11 +354,12 @@ async fn main() -> Result<()> {
}

let app = axum::Router::new()
.route("/tts", axum::routing::get(get_tts))
.route("/voices", axum::routing::get(get_voices))
.route("/tts", get(get_tts))
.route("/voices", get(get_voices))
.route("/translation_languages", get(get_translation_languages))
.route(
"/modes",
axum::routing::get(|| async {
get(|| async {
axum::Json([
TTSMode::gTTS.to_string(),
TTSMode::Polly.to_string(),
Expand Down
51 changes: 48 additions & 3 deletions src/translation.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::marker::PhantomData;

use anyhow::Result;
use serde::ser::SerializeStruct;
use small_fixed_array::FixedString;

fn deserialize_single_seq<'de, T, D>(deserializer: D) -> Result<Option<T>, D::Error>
Expand Down Expand Up @@ -47,9 +48,13 @@ struct TranslateResponse {
pub translations: Option<Translation>,
}

fn auth_header(token: &str) -> String {
format!("DeepL-Auth-Key {token}")
}

pub async fn run(
reqwest: &reqwest::Client,
translation_token: &str,
token: &str,
content: &str,
target_lang: &str,
) -> Result<Option<FixedString>> {
Expand All @@ -59,11 +64,10 @@ pub async fn run(
preserve_formatting: 1,
};

let auth_header = format!("DeepL-Auth-Key {translation_token}");
let response: TranslateResponse = reqwest
.get("https://api.deepl.com/v2/translate")
.query(&request)
.header("Authorization", auth_header)
.header("Authorization", auth_header(token))
.send()
.await?
.error_for_status()?
Expand All @@ -78,3 +82,44 @@ pub async fn run(

Ok(None)
}

#[derive(serde::Deserialize)]
struct Voice {
pub name: FixedString,
pub language: FixedString,
}

struct VoiceRequest;
impl serde::Serialize for VoiceRequest {
fn serialize<S>(&self, serializer: S) -> std::prelude::v1::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut serializer = serializer.serialize_struct("DeeplVoiceRequest", 1)?;
serializer.serialize_field("type", "target")?;
serializer.end()
}
}

pub async fn get_languages(
reqwest: &reqwest::Client,
token: &str,
) -> Result<Vec<(FixedString, FixedString)>> {
let languages: Vec<Voice> = reqwest
.get("https://api.deepl.com/v2/languages")
.query(&VoiceRequest)
.header("Authorization", auth_header(token))
.send()
.await?
.error_for_status()?
.json()
.await?;

let language_map = languages
.into_iter()
.map(|v| (v.language, v.name))
.collect();

println!("Loaded DeepL translation languages");
Ok(language_map)
}

0 comments on commit 15a07b9

Please sign in to comment.