Skip to content

Commit

Permalink
feat(rust): Added conversion traits for AnyValues and Series to Datums
Browse files Browse the repository at this point in the history
  • Loading branch information
schmidmt committed May 24, 2024
1 parent 83fd246 commit 62178a0
Showing 1 changed file with 272 additions and 0 deletions.
272 changes: 272 additions & 0 deletions lace/lace_codebook/src/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,288 @@ use crate::{
};

use lace_consts::rv::prelude::UnitPowerLaw;
use lace_data::{Category, Datum};
use lace_stats::prior::csd::CsdHyper;
use lace_stats::prior::nix::NixHyper;
use lace_stats::prior::pg::PgHyper;
use lace_stats::prior::sbd::SbdHyper;
use polars::datatypes::AnyValue;
use polars::prelude::{CsvReader, DataFrame, DataType, SerReader, Series};
use std::convert::TryFrom;
use std::path::Path;
use thiserror::Error;

/// Default categorical cutoff.
///
/// NOTE(review): not referenced in the visible part of this file; presumably
/// the largest number of distinct values for which an integer column is still
/// inferred as categorical — confirm against its use sites.
pub const DEFAULT_CAT_CUTOFF: u8 = 20;

/// An error from converting a Polars `AnyValue` to a `Datum`.
#[derive(Debug, Error)]
pub enum ConversionError {
/// The value (rendered as a string) is not a member of the column's
/// categories, or the column's value map does not match the value's type.
#[error("The given value `{0}` is not an existing category.")]
ValueNotACategory(String),
/// A string value was given for a categorical column whose value map is
/// not `ValueMap::String`.
#[error("The category is not indexed by a string.")]
CategoryNotIndexedByString,
/// The (value, column type) combination has no conversion. The payload is
/// the string rendering of the offending value.
#[error("The type `{0}` is not supported by this conversion.")]
UnsupportedType(String),
/// An integer value could not be converted into the integer type held by
/// `Datum::Count`.
#[error("The category count is out of the existing bounds.")]
CountOutOfBounds,
/// An integer value could not be converted into the integer type held by
/// `Datum::Index`.
#[error("The index is out of the existing bounds.")]
IndexOutOfBounds,
}

/// Single value to `Datum` conversion helper.
pub trait AnyValueDatumExt: Sized {
/// Convert an `AnyValue` to a `Datum` with a specific coltype.
///
/// * `coltype` - Column type to convert datum into.
/// * `drop_out_of_category` - Set output to `Datum::Missing` if given value is not contained
/// in the existing category.
///
/// # Errors
///
/// Returns a `ConversionError` when the value cannot be represented as a
/// `Datum` of the requested column type.
fn to_datum(
self,
coltype: &ColType,
drop_out_of_category: bool,
) -> Result<Datum, ConversionError>;
}

/// Convert an integer expression into a `u8`-indexed categorical `Datum`.
///
/// * `$x` - Integer expression to convert.
/// * `$value_map` - The column's value map; only `ValueMap::U8(size)` is
///   accepted.
/// * `$out_to_missing` - When `true`, an index that fits in a `u8` but is not
///   below `size` becomes `Datum::Missing` instead of an error.
///
/// The macro uses `?`, so it must be expanded inside a function returning
/// `Result<_, ConversionError>`. A value that does not fit in a `u8` is always
/// reported as `ConversionError::ValueNotACategory`, regardless of
/// `$out_to_missing`.
macro_rules! int_to_category {
    ($x: expr, $value_map: expr, $out_to_missing: expr) => {{
        // Category indices are stored as `u8`; anything that does not fit
        // cannot be a category.
        let idx = u8::try_from($x).map_err(|_| {
            ConversionError::ValueNotACategory(format!("{}", $x))
        })?;

        match $value_map {
            ValueMap::U8(size) if (idx as usize) < *size => {
                Ok(Datum::Categorical(Category::U8(idx)))
            }
            // In `u8` range but not below `size`: drop it if requested.
            ValueMap::U8(_) if $out_to_missing => Ok(Datum::Missing),
            // Either out of range without dropping, or the column is not
            // indexed by `u8` at all.
            _ => Err(ConversionError::ValueNotACategory(idx.to_string())),
        }
    }};
}

impl<'a> AnyValueDatumExt for AnyValue<'a> {
fn to_datum(
self,
coltype: &ColType,
drop_out_of_category: bool,
) -> Result<Datum, ConversionError> {
match (self, coltype) {
(AnyValue::Null, _) => Ok(Datum::Missing),
(AnyValue::String(s), ColType::Categorical { value_map, .. }) => {
if let ValueMap::String(cat_map) = value_map {
if let Some(_cat_idx) = cat_map.ix(s) {
Ok(Datum::Categorical(lace_data::Category::String(
s.to_string(),
)))
} else {
if drop_out_of_category {
Ok(Datum::Missing)
} else {
Err(ConversionError::ValueNotACategory(
s.to_string(),
))
}
}
} else {
Err(ConversionError::CategoryNotIndexedByString)
}
}

(
AnyValue::StringOwned(s),
ColType::Categorical { value_map, .. },
) => {
if let ValueMap::String(cat_map) = value_map {
if let Some(_cat_idx) = cat_map.ix(&s.to_string()) {
Ok(Datum::Categorical(lace_data::Category::String(
s.to_string(),
)))
} else {
if drop_out_of_category {
Ok(Datum::Missing)
} else {
Err(ConversionError::ValueNotACategory(
s.to_string(),
))
}
}
} else {
Err(ConversionError::CategoryNotIndexedByString)
}
}

(AnyValue::Boolean(b), ColType::Categorical { value_map, .. }) => {
if let ValueMap::Bool = value_map {
Ok(Datum::Binary(b))
} else {
Err(ConversionError::ValueNotACategory(b.to_string()))
}
}
(AnyValue::UInt8(x), ColType::Categorical { value_map, .. }) => {
int_to_category!(x, value_map, drop_out_of_category)
}
(AnyValue::UInt8(x), ColType::Count { .. }) => Ok(Datum::Count(
x.try_into()
.map_err(|_| ConversionError::CountOutOfBounds)?,
)),
(AnyValue::UInt8(x), ColType::StickBreakingDiscrete { .. }) => {
Ok(Datum::Index(x.into()))
}
(AnyValue::UInt16(x), ColType::Categorical { value_map, .. }) => {
int_to_category!(x, value_map, drop_out_of_category)
}
(AnyValue::UInt16(x), ColType::Count { .. }) => Ok(Datum::Count(
x.try_into()
.map_err(|_| ConversionError::CountOutOfBounds)?,
)),
(AnyValue::UInt16(x), ColType::StickBreakingDiscrete { .. }) => {
Ok(Datum::Index(x.into()))
}
(AnyValue::UInt32(x), ColType::Categorical { value_map, .. }) => {
int_to_category!(x, value_map, drop_out_of_category)
}
(AnyValue::UInt32(x), ColType::Count { .. }) => Ok(Datum::Count(
x.try_into()
.map_err(|_| ConversionError::CountOutOfBounds)?,
)),
(AnyValue::UInt32(x), ColType::StickBreakingDiscrete { .. }) => {
Ok(Datum::Index(
x.try_into()
.map_err(|_| ConversionError::IndexOutOfBounds)?,
))
}
(AnyValue::UInt64(x), ColType::Categorical { value_map, .. }) => {
int_to_category!(x, value_map, drop_out_of_category)
}
(AnyValue::UInt64(x), ColType::Count { .. }) => Ok(Datum::Count(
x.try_into()
.map_err(|_| ConversionError::CountOutOfBounds)?,
)),
(AnyValue::UInt64(x), ColType::StickBreakingDiscrete { .. }) => {
Ok(Datum::Index(
x.try_into()
.map_err(|_| ConversionError::IndexOutOfBounds)?,
))
}

(AnyValue::Int8(x), ColType::Categorical { value_map, .. }) => {
int_to_category!(x, value_map, drop_out_of_category)
}
(AnyValue::Int8(x), ColType::Count { .. }) => Ok(Datum::Count(
x.try_into()
.map_err(|_| ConversionError::CountOutOfBounds)?,
)),
(AnyValue::Int8(x), ColType::StickBreakingDiscrete { .. }) => {
Ok(Datum::Index(
x.try_into()
.map_err(|_| ConversionError::IndexOutOfBounds)?,
))
}
(AnyValue::Int16(x), ColType::Categorical { value_map, .. }) => {
int_to_category!(x, value_map, drop_out_of_category)
}
(AnyValue::Int16(x), ColType::Count { .. }) => Ok(Datum::Count(
x.try_into()
.map_err(|_| ConversionError::CountOutOfBounds)?,
)),
(AnyValue::Int16(x), ColType::StickBreakingDiscrete { .. }) => {
Ok(Datum::Index(
x.try_into()
.map_err(|_| ConversionError::IndexOutOfBounds)?,
))
}
(AnyValue::Int32(x), ColType::Categorical { value_map, .. }) => {
int_to_category!(x, value_map, drop_out_of_category)
}
(AnyValue::Int32(x), ColType::Count { .. }) => Ok(Datum::Count(
x.try_into()
.map_err(|_| ConversionError::CountOutOfBounds)?,
)),
(AnyValue::Int32(x), ColType::StickBreakingDiscrete { .. }) => {
Ok(Datum::Index(
x.try_into()
.map_err(|_| ConversionError::IndexOutOfBounds)?,
))
}
(AnyValue::Int64(x), ColType::Categorical { value_map, .. }) => {
int_to_category!(x, value_map, drop_out_of_category)
}
(AnyValue::Int64(x), ColType::Count { .. }) => Ok(Datum::Count(
x.try_into()
.map_err(|_| ConversionError::CountOutOfBounds)?,
)),
(AnyValue::Int64(x), ColType::StickBreakingDiscrete { .. }) => {
Ok(Datum::Index(
x.try_into()
.map_err(|_| ConversionError::IndexOutOfBounds)?,
))
}

(AnyValue::Float32(x), ColType::Continuous { .. }) => {
Ok(Datum::Continuous(x.into()))
}
(AnyValue::Float64(x), ColType::Continuous { .. }) => {
Ok(Datum::Continuous(x))
}

(av, _) => Err(ConversionError::UnsupportedType(av.to_string())),
}
}
}

/// Series to collection of `Datum` conversion helper.
pub trait SeriesDatumExt {
/// Convert a `polars::Series` to a `Vec<Datum>` with a specific coltype.
///
/// * `col_type` - Column type to convert datum into.
/// * `drop_out_of_category` - Set output to `Datum::Missing` if given value is not contained
/// in the existing category.
///
/// # Errors
///
/// Returns a `ConversionError` if a value cannot be converted into a
/// `Datum` of the given column type.
fn to_datum_vec(
self,
col_type: &ColType,
drop_out_of_category: bool,
) -> Result<Vec<Datum>, ConversionError>;

/// Convert a `polars::Series` to an iterator of `Datum`s.
///
/// * `col_type` - Column type to convert datum into.
/// * `drop_out_of_category` - Set output to `Datum::Missing` if given value is not contained
/// in the existing category.
///
/// Each item is the conversion result for one value, so a failed
/// conversion is reported as an `Err` item.
fn as_datum_iter(
&self,
col_type: &ColType,
drop_out_of_category: bool,
) -> impl Iterator<Item = Result<Datum, ConversionError>>;
}

impl SeriesDatumExt for Series {
    /// Convert every value of the series into a `Datum` of type `col_type`.
    ///
    /// * `col_type` - Column type to convert datum into.
    /// * `drop_out_of_category` - Set output to `Datum::Missing` if given
    ///   value is not contained in the existing category.
    ///
    /// # Errors
    ///
    /// Returns the first `ConversionError` encountered.
    fn to_datum_vec(
        self,
        col_type: &ColType,
        drop_out_of_category: bool,
    ) -> Result<Vec<Datum>, ConversionError> {
        // XXX: `Series::iter` only supports single-chunk series (it panics
        // otherwise), so rechunk first. Remove this if polars lifts that
        // restriction.
        let arr = self.rechunk();
        arr.iter()
            .map(|x: AnyValue| x.to_datum(col_type, drop_out_of_category))
            .collect::<Result<Vec<Datum>, _>>()
    }

    /// Iterate over the values of the series as `Datum`s of type `col_type`.
    ///
    /// * `col_type` - Column type to convert datum into.
    /// * `drop_out_of_category` - Set output to `Datum::Missing` if given
    ///   value is not contained in the existing category.
    ///
    /// A single-chunk series is converted lazily, one value per `next()`
    /// call. A multi-chunk series is rechunked and converted up front,
    /// because `Series::iter` panics unless the series has exactly one chunk
    /// and the rechunked copy cannot be borrowed by the returned iterator.
    /// Either way, a failed conversion is reported as an `Err` item.
    fn as_datum_iter(
        &self,
        col_type: &ColType,
        drop_out_of_category: bool,
    ) -> impl Iterator<Item = Result<Datum, ConversionError>> {
        // Exactly one of `lazy` / `buffered` is populated; chaining the two
        // flattened options yields a single concrete iterator type without
        // boxing.
        let (lazy, buffered) = if self.n_chunks() == 1 {
            let values = self.iter().map(move |x: AnyValue| {
                x.to_datum(col_type, drop_out_of_category)
            });
            (Some(values), None)
        } else {
            // XXX: Rechunk is only required because `Series::iter` does not
            // support multiple chunks; remove this if polars fixes it.
            let arr = self.rechunk();
            let values: Vec<Result<Datum, ConversionError>> = arr
                .iter()
                .map(|x: AnyValue| x.to_datum(col_type, drop_out_of_category))
                .collect();
            (None, Some(values))
        };

        lazy.into_iter()
            .flatten()
            .chain(buffered.into_iter().flatten())
    }
}

#[macro_export]
macro_rules! series_to_opt_vec {
($srs: ident, $X: ty) => {{
Expand Down

0 comments on commit 62178a0

Please sign in to comment.