From 5e0b0ce31df1d39f4e182c140b906ccf202f7f7f Mon Sep 17 00:00:00 2001 From: Joel Natividad <1980690+jqnatividad@users.noreply.github.com> Date: Sun, 9 Feb 2025 16:08:41 -0500 Subject: [PATCH 1/2] feat: `sample` add four more sampling methods - systematic, stratified, weighted & cluster --- src/cmd/sample.rs | 769 +++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 689 insertions(+), 80 deletions(-) diff --git a/src/cmd/sample.rs b/src/cmd/sample.rs index f4d4ae587..3b8b9eab2 100644 --- a/src/cmd/sample.rs +++ b/src/cmd/sample.rs @@ -1,24 +1,57 @@ static USAGE: &str = r#" Randomly samples CSV data. -It supports three sampling methods: -- INDEXED: the default sampling method when an index is present. - Uses random I/O to sample efficiently, as it only visits records selected - by random indexing, using CONSTANT memory proportional to the . - The number of records in the output is exactly equal to the . - +It supports seven sampling methods: - RESERVOIR: the default sampling method when NO INDEX is present. - Visits every CSV record exactly once, using memory proportional to . - The number of records in the output is exactly equal to the . + Visits every CSV record exactly once, using MEMORY PROPORTIONAL to the + sample size (k) - O(k). https://en.wikipedia.org/wiki/Reservoir_sampling +- INDEXED: the default sampling method when an index is present. + Uses random I/O to sample efficiently, as it only visits records selected + by random indexing, using MEMORY PROPORTIONAL to the sample size (k) - O(k). + https://en.wikipedia.org/wiki/Random_access + - BERNOULLI: the sampling method when the --bernoulli option is specified. Visits every CSV record exactly once and selects records with a given probability - as specified by the argument. It uses constant memory. - The number of records in the output follows a binomial distribution with - parameters n (input size) and p (sample-size as probability). + as specified by the argument. 
The number of records in the output + follows a binomial distribution with parameters n (input size) and + p (sample-size as probability). Uses CONSTANT memory - O(1). https://en.wikipedia.org/wiki/Bernoulli_sampling +- SYSTEMATIC: the sampling method when the --systematic option is specified. + Selects every nth record from the input, where n = population_size/sample_size. + The sample size must be a whole number. Uses CONSTANT memory - O(1). + The starting point can be specified as "random" or "first". + Useful for time series data or when you want evenly spaced samples. + https://en.wikipedia.org/wiki/Systematic_sampling + +- STRATIFIED: the sampling method when the --stratified option is specified. + Stratifies the population by the specified column and then samples from each stratum. + Particularly useful when a population has distinct subgroups (strata) that are + homogeneous within but heterogeneous between in terms of the variable of interest. + For example, if you want to sample 1000 records from a population of 100,000, + you can stratify the population by gender and then sample 500 records from each + stratum. This will ensure that you have a representative sample from each gender. + The sample size must be a whole number. Uses MEMORY PROPORTIONAL to the + number of strata (s) and samples per stratum (k) - O(s*k). + https://en.wikipedia.org/wiki/Stratified_sampling + +- WEIGHTED: the sampling method when the --weighted option is specified. + Samples records with probability proportional to weights in the specified weight column. + Useful when some records are more important than others. + Uses MEMORY PROPORTIONAL to the sample size (k) - O(k). + "Weighted random sampling with a reservoir" https://doi.org/10.1016/j.ipl.2005.11.003 + +- CLUSTER: the sampling method when the --cluster option is specified. + Samples entire groups of records together based on a cluster identifier column. 
+ Useful when records are naturally grouped (e.g., by household, neighborhood, etc.). + For example, if you want to sample from a population of 100,000 records, + you can cluster the population by neighborhood and then randomly select 100 neighborhoods + (clusters). Every record from each selected neighborhood is included in the sample. + Uses MEMORY PROPORTIONAL to the number of clusters (c) - O(c). + https://en.wikipedia.org/wiki/Cluster_sampling + Supports sampling from CSVs on remote URLs. This command is intended to provide a means to sample from a CSV data set that @@ -35,7 +68,10 @@ sample arguments: The CSV file to sample. This can be a local file, stdin, or a URL (http and https schemes supported). - When using INDEXED or RESERVOIR sampling, the number of records to sample. + When using INDEXED, RESERVOIR or WEIGHTED sampling, the number of records to sample. + When using SYSTEMATIC sampling, the interval between records to sample. + When using STRATIFIED sampling, the number of records to sample per stratum. + When using CLUSTER sampling, the number of clusters to sample. When using BERNOULLI sampling, the probability of selecting each record (between 0 and 1). @@ -51,9 +87,25 @@ sample options: Recommended by eSTREAM (https://www.ecrypt.eu.org/stream/). 2.1 GB/s throughput though slow initialization. [default: standard] + + SAMPLING METHODS: --bernoulli Use Bernoulli sampling instead of indexed or reservoir sampling. When this flag is set, the sample-size must be between 0 and 1 and represents the probability of selecting each record. + --systematic Use systematic sampling (every nth record as specified by sample-size). + If is "random", the starting point is randomly chosen between 0 & n. + If is "first", the starting point is the first record. + The sample size must be a whole number. Uses CONSTANT memory - O(1). + --stratified Use stratified sampling. The strata column is specified by . 
+ Can be either a column name or 0-based column index. + The sample size must be a whole number. Uses MEMORY PROPORTIONAL to the + number of strata (s) and samples per stratum (k) - O(s*k). + --weighted Use weighted sampling. The weight column is specified by . + Can be either a column name or 0-based column index. + Uses MEMORY PROPORTIONAL to the sample size (k) - O(k). + --cluster Use cluster sampling. The cluster column is specified by . + Can be either a column name or 0-based column index. + Uses MEMORY PROPORTIONAL to the number of clusters (c) - O(c). REMOTE FILE OPTIONS: --user-agent Specify custom user agent to use when the input is a URL. @@ -79,15 +131,21 @@ Common options: Must be a single character. (default: ,) "#; -use std::{io, str::FromStr}; +use std::{ + collections::{HashMap, HashSet}, + io, + str::FromStr, +}; use rand::{ distr::{Bernoulli, Distribution}, + prelude::IndexedRandom, rngs::StdRng, Rng, SeedableRng, }; use rand_hc::Hc128Rng; use rand_xoshiro::Xoshiro256Plus; +use rayon::prelude::ParallelSliceMut; use serde::Deserialize; use strum_macros::EnumString; use tempfile::NamedTempFile; @@ -111,6 +169,70 @@ struct Args { flag_timeout: Option, flag_max_size: Option, flag_bernoulli: bool, + flag_systematic: Option, + flag_stratified: Option, + flag_weighted: Option, + flag_cluster: Option, +} + +impl Args { + fn get_column_index( + #[allow(clippy::unused_self)] &self, + header: &csv::ByteRecord, + column_spec: &str, + purpose: &str, + ) -> CliResult { + // Try parsing as number first + if let Ok(idx) = column_spec.parse::() { + if idx < header.len() { + return Ok(idx); + } + return fail_incorrectusage_clierror!( + "{} column index {} is out of bounds (max: {})", + purpose, + idx, + header.len() - 1 + ); + } + + // If not a number, try to find column by name + for (i, field) in header.iter().enumerate() { + if column_spec == String::from_utf8_lossy(field) { + return Ok(i); + } + } + + fail_incorrectusage_clierror!("Could not find {} column 
named '{}'", purpose, column_spec) + } + + fn get_strata_column(&self, header: &csv::ByteRecord) -> CliResult { + match &self.flag_stratified { + Some(col) => self.get_column_index(header, col, "strata"), + None => { + fail_incorrectusage_clierror!( + "--stratified is required for stratified sampling" + ) + }, + } + } + + fn get_weight_column(&self, header: &csv::ByteRecord) -> CliResult { + match &self.flag_weighted { + Some(col) => self.get_column_index(header, col, "weight"), + None => { + fail_incorrectusage_clierror!("--weighted is required for weighted sampling") + }, + } + } + + fn get_cluster_column(&self, header: &csv::ByteRecord) -> CliResult { + match &self.flag_cluster { + Some(col) => self.get_column_index(header, col, "cluster"), + None => { + fail_incorrectusage_clierror!("--cluster is required for cluster sampling") + }, + } + } } #[derive(Debug, EnumString, PartialEq)] @@ -121,6 +243,15 @@ enum RngKind { Cryptosecure, } +enum SamplingMethod { + Bernoulli, + Systematic, + Stratified, + Weighted, + Cluster, + Default, +} + // trait to handle different RNG types trait RngProvider: Sized { type RngType: Rng + SeedableRng; @@ -167,6 +298,18 @@ impl RngProvider for CryptoRng { pub fn run(argv: &[&str]) -> CliResult<()> { let mut args: Args = util::get_args(USAGE, argv)?; + // Validate that only one sampling method is selected + let methods = [ + args.flag_bernoulli, + args.flag_systematic.is_some(), + args.flag_stratified.is_some(), + args.flag_weighted.is_some(), + args.flag_cluster.is_some(), + ]; + if methods.iter().filter(|&&x| x).count() > 1 { + return fail_incorrectusage_clierror!("Only one sampling method can be specified"); + } + let Ok(rng_kind) = RngKind::from_str(&args.flag_rng) else { return fail_incorrectusage_clierror!( "Invalid RNG algorithm `{}`. 
Supported RNGs are: standard, faster, cryptosecure.", @@ -174,8 +317,25 @@ pub fn run(argv: &[&str]) -> CliResult<()> { ); }; + let sampling_method = match ( + args.flag_bernoulli, + args.flag_systematic.is_some(), + args.flag_stratified.is_some(), + args.flag_weighted.is_some(), + args.flag_cluster.is_some(), + ) { + (true, _, _, _, _) => SamplingMethod::Bernoulli, + (_, true, _, _, _) => SamplingMethod::Systematic, + (_, _, true, _, _) => SamplingMethod::Stratified, + (_, _, _, true, _) => SamplingMethod::Weighted, + (_, _, _, _, true) => SamplingMethod::Cluster, + (false, false, false, false, false) => SamplingMethod::Default, + }; + let temp_download = NamedTempFile::new()?; + // Clone the user_agent before using it + let user_agent = args.flag_user_agent.clone(); args.arg_input = match args.arg_input { Some(uri) if Url::parse(&uri).is_ok() && uri.starts_with("http") => { let max_size_bytes = args.flag_max_size.map(|mb| mb * 1024 * 1024); @@ -185,7 +345,7 @@ pub fn run(argv: &[&str]) -> CliResult<()> { &uri, temp_download.path().to_path_buf(), false, - args.flag_user_agent, + user_agent, args.flag_timeout, max_size_bytes, ); @@ -203,79 +363,147 @@ pub fn run(argv: &[&str]) -> CliResult<()> { .flexible(true) .skip_format_check(true); - let mut sample_size = args.arg_sample_size; - + let mut rdr = rconfig.reader()?; let mut wtr = Config::new(args.flag_output.as_ref()) .delimiter(args.flag_delimiter) .writer()?; - if args.flag_bernoulli { - if sample_size >= 1.0 || sample_size <= 0.0 { - return fail_incorrectusage_clierror!( - "Bernoulli sampling requires a probability between 0 and 1" - ); - } + // Write headers unless --no-headers is specified + rconfig.write_headers(&mut rdr, &mut wtr)?; - let mut rdr = rconfig.reader()?; - rconfig.write_headers(&mut rdr, &mut wtr)?; - sample_bernoulli(&mut rdr, &mut wtr, sample_size, args.flag_seed, &rng_kind)?; - } else if let Some(mut idx) = rconfig.indexed()? 
{ - // an index is present, so use random indexing - #[allow(clippy::cast_precision_loss)] - if sample_size < 1.0 { - sample_size *= idx.count() as f64; - } - rconfig.write_headers(&mut *idx, &mut wtr)?; + let mut sample_size = args.arg_sample_size; - let sample_count = sample_size as usize; - let total_count = idx.count().try_into().unwrap(); + match sampling_method { + SamplingMethod::Bernoulli => { + if args.arg_sample_size >= 1.0 || args.arg_sample_size <= 0.0 { + return fail_incorrectusage_clierror!( + "Bernoulli sampling requires a probability between 0 and 1" + ); + } - match rng_kind { - RngKind::Standard => { - log::info!("doing standard INDEXED sampling..."); - let mut rng = StandardRng::create(args.flag_seed); - sample_indices(&mut rng, total_count, sample_count, |i| { - idx.seek(i as u64)?; - Ok(wtr.write_byte_record(&idx.byte_records().next().unwrap()?)?) - })?; - }, - RngKind::Faster => { - log::info!("doing --faster INDEXED sampling..."); - let mut rng = FasterRng::create(args.flag_seed); - sample_indices(&mut rng, total_count, sample_count, |i| { - idx.seek(i as u64)?; - Ok(wtr.write_byte_record(&idx.byte_records().next().unwrap()?)?) - })?; - }, - RngKind::Cryptosecure => { - log::info!("doing --cryptosecure INDEXED sampling..."); - let mut rng = CryptoRng::create(args.flag_seed); - sample_indices(&mut rng, total_count, sample_count, |i| { - idx.seek(i as u64)?; - Ok(wtr.write_byte_record(&idx.byte_records().next().unwrap()?)?) - })?; - }, - } - } else { - // bernoulli sampling is not specified nor is an index present - // so we do reservoir sampling - #[allow(clippy::cast_precision_loss)] - if sample_size < 1.0 { - let Ok(row_count) = util::count_rows(&rconfig) else { - return fail!("Cannot get rowcount. 
Percentage sampling requires a rowcount."); + sample_bernoulli( + &mut rdr, + &mut wtr, + args.arg_sample_size, + args.flag_seed, + &rng_kind, + )?; + }, + SamplingMethod::Systematic => { + let starting_point = match args.flag_systematic.as_deref().map(str::to_lowercase) { + Some(arg) if arg == "random" || arg == "first" => arg, + Some(_) => { + return fail_incorrectusage_clierror!( + "Systematic sampling starting point must be either 'random' or 'first'" + ) + }, + None => String::from("random"), }; - sample_size *= row_count as f64; - } - let mut rdr = rconfig.reader()?; - rconfig.write_headers(&mut rdr, &mut wtr)?; + let row_count = util::count_rows(&rconfig)?; + sample_systematic( + &mut rdr, + &mut wtr, + args.arg_sample_size, + row_count, + &starting_point, + args.flag_seed, + &rng_kind, + )?; + }, + SamplingMethod::Stratified => { + let strata_column = args.get_strata_column(&rdr.byte_headers()?.clone())?; + sample_stratified( + &mut rdr, + &mut wtr, + strata_column, + args.arg_sample_size as usize, + args.flag_seed, + &rng_kind, + )?; + }, + SamplingMethod::Weighted => { + let weight_column = args.get_weight_column(&rdr.byte_headers()?.clone())?; + sample_weighted( + &rconfig, + &mut rdr, + &mut wtr, + weight_column, + args.arg_sample_size as usize, + args.flag_seed, + &rng_kind, + )?; + }, + SamplingMethod::Cluster => { + let cluster_column = args.get_cluster_column(&rdr.byte_headers()?.clone())?; + sample_cluster( + &rconfig, + &mut rdr, + &mut wtr, + cluster_column, + args.arg_sample_size as usize, + args.flag_seed, + &rng_kind, + )?; + }, + SamplingMethod::Default => { + // no sampling method is specified, so we do indexed sampling + // if an index is present + if let Some(mut idx) = rconfig.indexed()? 
{ + #[allow(clippy::cast_precision_loss)] + if sample_size < 1.0 { + sample_size *= idx.count() as f64; + } + + let sample_count = sample_size as usize; + let total_count = idx.count().try_into().unwrap(); - sample_reservoir( - &mut rdr, - &mut wtr, - sample_size as u64, - args.flag_seed, - &rng_kind, - )?; + match rng_kind { + RngKind::Standard => { + log::info!("doing standard INDEXED sampling..."); + let mut rng = StandardRng::create(args.flag_seed); + sample_indices(&mut rng, total_count, sample_count, |i| { + idx.seek(i as u64)?; + Ok(wtr.write_byte_record(&idx.byte_records().next().unwrap()?)?) + })?; + }, + RngKind::Faster => { + log::info!("doing --faster INDEXED sampling..."); + let mut rng = FasterRng::create(args.flag_seed); + sample_indices(&mut rng, total_count, sample_count, |i| { + idx.seek(i as u64)?; + Ok(wtr.write_byte_record(&idx.byte_records().next().unwrap()?)?) + })?; + }, + RngKind::Cryptosecure => { + log::info!("doing --cryptosecure INDEXED sampling..."); + let mut rng = CryptoRng::create(args.flag_seed); + sample_indices(&mut rng, total_count, sample_count, |i| { + idx.seek(i as u64)?; + Ok(wtr.write_byte_record(&idx.byte_records().next().unwrap()?)?) + })?; + }, + } + } else { + // No sampling method is specified and no index is present + // do reservoir sampling + #[allow(clippy::cast_precision_loss)] + if args.arg_sample_size < 1.0 { + let Ok(row_count) = util::count_rows(&rconfig) else { + return fail!( + "Cannot get rowcount. Percentage sampling requires a rowcount." + ); + }; + args.arg_sample_size *= row_count as f64; + } + sample_reservoir( + &mut rdr, + &mut wtr, + args.arg_sample_size as u64, + args.flag_seed, + &rng_kind, + )?; + } + }, } Ok(wtr.flush()?) 
@@ -392,8 +620,6 @@ fn sample_indices( where F: FnMut(usize) -> CliResult<()>, { - use rayon::prelude::ParallelSliceMut; - if sample_count > total_count { return fail!("Sample size cannot be larger than population size"); } @@ -422,3 +648,386 @@ where Ok(()) } + +// Systematic sampling implementation +fn sample_systematic( + rdr: &mut csv::Reader, + wtr: &mut csv::Writer, + sample_size: f64, + row_count: u64, + starting_point: &str, + seed: Option, + rng_kind: &RngKind, +) -> CliResult<()> { + if sample_size <= 0.0 { + return fail_incorrectusage_clierror!("Sample size must be positive"); + } + + let sample_size = sample_size.round() as usize; + if sample_size as u64 > row_count { + return fail_incorrectusage_clierror!("Sample size cannot be larger than population size"); + } + + // Select starting point + let start = if starting_point == "random" { + match rng_kind { + RngKind::Standard => { + let mut rng = StandardRng::create(seed); + rng.random_range(0..sample_size) + }, + RngKind::Faster => { + let mut rng = FasterRng::create(seed); + rng.random_range(0..sample_size) + }, + RngKind::Cryptosecure => { + let mut rng = CryptoRng::create(seed); + rng.random_range(0..sample_size) + }, + } + } else { + 0 // starting point is the first record + }; + + // Select records at regular intervals + for (i, record) in rdr.byte_records().enumerate().skip(start) { + if i % sample_size == 0 && (i / sample_size) < sample_size { + wtr.write_byte_record(&record?)?; + } + } + Ok(()) +} + +// Stratified sampling implementation +fn sample_stratified( + rdr: &mut csv::Reader, + wtr: &mut csv::Writer, + strata_column: usize, + samples_per_stratum: usize, + seed: Option, + rng_kind: &RngKind, +) -> CliResult<()> { + const ESTIMATED_STRATA_COUNT: usize = 100; + + // Pre-allocate with capacity for better performance + let mut strata_counts: HashMap, usize> = HashMap::with_capacity(ESTIMATED_STRATA_COUNT); + let mut records = Vec::with_capacity(ESTIMATED_STRATA_COUNT * samples_per_stratum); 
+ + // First pass: count strata and collect records + for record in rdr.byte_records() { + let record = record?; + let stratum = record + .get(strata_column) + .ok_or_else(|| format!("Strata column index {strata_column} out of bounds"))? + .to_vec(); + *strata_counts.entry(stratum.clone()).or_default() += 1; + records.push(record); + } + + let strata_count = strata_counts.len(); + if strata_count == 0 { + return fail_incorrectusage_clierror!("No valid strata found in the data"); + } + + // Initialize reservoirs with capacity + let mut reservoirs: HashMap, Vec> = + HashMap::with_capacity(strata_count); + for stratum in strata_counts.keys() { + reservoirs.insert(stratum.clone(), Vec::with_capacity(samples_per_stratum)); + } + + // Create RNG and perform sampling + match rng_kind { + RngKind::Standard => { + let mut rng = StandardRng::create(seed); + do_stratified_sampling( + records.into_iter(), + &mut reservoirs, + strata_column, + samples_per_stratum, + &mut rng, + )?; + }, + RngKind::Faster => { + let mut rng = FasterRng::create(seed); + do_stratified_sampling( + records.into_iter(), + &mut reservoirs, + strata_column, + samples_per_stratum, + &mut rng, + )?; + }, + RngKind::Cryptosecure => { + let mut rng = CryptoRng::create(seed); + do_stratified_sampling( + records.into_iter(), + &mut reservoirs, + strata_column, + samples_per_stratum, + &mut rng, + )?; + }, + } + + // Write results in deterministic order + let mut strata: Vec<_> = reservoirs.keys().collect(); + strata.par_sort_unstable(); + for stratum in strata { + if let Some(records) = reservoirs.get(stratum) { + for record in records { + wtr.write_byte_record(record)?; + } + } + } + + Ok(()) +} + +fn do_stratified_sampling( + records: impl Iterator, + reservoirs: &mut HashMap, Vec>, + strata_column: usize, + samples_per_stratum: usize, + rng: &mut T, +) -> CliResult<()> { + let mut records_seen: HashMap, usize> = HashMap::with_capacity(reservoirs.len()); + + for record in records { + let stratum = record + 
.get(strata_column) + .ok_or_else(|| format!("Strata column index {strata_column} out of bounds"))? + .to_vec(); + + let seen = records_seen.entry(stratum.clone()).or_default(); + + if let Some(reservoir) = reservoirs.get_mut(&stratum) { + if reservoir.len() < samples_per_stratum { + reservoir.push(record); + } else { + let j = rng.random_range(0..=*seen); + if j < samples_per_stratum { + // safety: we know that j is within the bounds of the reservoir + unsafe { *reservoir.get_unchecked_mut(j) = record }; + } + } + *seen += 1; + } + } + Ok(()) +} + +// Weighted sampling implementation +fn sample_weighted( + rconfig: &Config, + rdr: &mut csv::Reader, + wtr: &mut csv::Writer, + weight_column: usize, + sample_size: usize, + seed: Option, + rng_kind: &RngKind, +) -> CliResult<()> { + // First pass: find maximum weight + let mut max_weight = 0.0f64; + for record in rdr.byte_records() { + let record = record?; + let weight = String::from_utf8_lossy( + record + .get(weight_column) + .ok_or_else(|| format!("Weight column index {weight_column} out of bounds"))?, + ) + .parse::() + .map_err(|_| "Invalid weight value")?; + + if weight < 0.0 { + return fail_incorrectusage_clierror!("Weights must be non-negative"); + } + max_weight = max_weight.max(weight); + } + + if max_weight == 0.0 { + return fail_incorrectusage_clierror!("All weights are zero"); + } + + // Second pass: acceptance-rejection sampling + let mut rdr2 = rconfig.reader()?; + + match rng_kind { + RngKind::Standard => { + log::info!("doing standard WEIGHTED sampling..."); + let mut rng = StandardRng::create(seed); + do_weighted_sampling( + &mut rdr2.byte_records(), + wtr, + weight_column, + sample_size, + max_weight, + &mut rng, + )?; + }, + RngKind::Faster => { + log::info!("doing --faster WEIGHTED sampling..."); + let mut rng = FasterRng::create(seed); + do_weighted_sampling( + &mut rdr2.byte_records(), + wtr, + weight_column, + sample_size, + max_weight, + &mut rng, + )?; + }, + RngKind::Cryptosecure => { + 
log::info!("doing --cryptosecure WEIGHTED sampling..."); + let mut rng = CryptoRng::create(seed); + do_weighted_sampling( + &mut rdr2.byte_records(), + wtr, + weight_column, + sample_size, + max_weight, + &mut rng, + )?; + }, + } + + Ok(()) +} + +// Helper function to handle the actual sampling with any RNG type +fn do_weighted_sampling( + records: &mut impl Iterator>, + wtr: &mut csv::Writer, + weight_column: usize, + sample_size: usize, + max_weight: f64, + rng: &mut T, +) -> CliResult<()> { + use std::collections::HashSet; + + let mut selected = HashSet::with_capacity(sample_size); + let mut attempts = 0; + let max_attempts = sample_size * 100; // Prevent infinite loops + + while selected.len() < sample_size && attempts < max_attempts { + for (i, record) in records.enumerate() { + if selected.len() >= sample_size { + break; + } + + let record = record?; + let weight = String::from_utf8_lossy( + record + .get(weight_column) + .ok_or_else(|| format!("Weight column index {weight_column} out of bounds"))?, + ) + .parse::() + .map_err(|_| "Invalid weight value")?; + + if weight < 0.0 { + return fail_incorrectusage_clierror!("Weights must be non-negative"); + } + + // Modified acceptance-rejection method to handle zero weights + let include_flag = if weight == 0.0 { + false + } else { + rng.random::() <= (weight / max_weight) + }; + + if include_flag && !selected.contains(&i) { + selected.insert(i); + wtr.write_byte_record(&record)?; + } + + attempts += 1; + if attempts >= max_attempts { + break; + } + } + } + + if selected.len() < sample_size { + log::warn!( + "Could only sample {} records out of requested {}", + selected.len(), + sample_size + ); + } + + Ok(()) +} + +// Cluster sampling implementation +fn sample_cluster( + rconfig: &Config, + rdr: &mut csv::Reader, + wtr: &mut csv::Writer, + cluster_column: usize, + n_clusters: usize, + seed: Option, + rng_kind: &RngKind, +) -> CliResult<()> { + const ESTIMATED_CLUSTER_COUNT: usize = 100; + + // Use HashSet for 
faster lookups of unique clusters + let mut unique_clusters: HashSet> = HashSet::with_capacity(ESTIMATED_CLUSTER_COUNT); + let mut all_clusters: Vec> = Vec::with_capacity(ESTIMATED_CLUSTER_COUNT); + + // First pass: collect unique clusters + for record in rdr.byte_records() { + let record = record?; + let cluster = record + .get(cluster_column) + .ok_or_else(|| format!("Cluster column index {cluster_column} out of bounds"))? + .to_vec(); + + if unique_clusters.insert(cluster.clone()) { + all_clusters.push(cluster); + } + } + + if unique_clusters.is_empty() { + return fail_incorrectusage_clierror!("No valid clusters found in the data"); + } + + // Select clusters + let selected_clusters: HashSet> = match rng_kind { + RngKind::Standard => { + let mut rng = StandardRng::create(seed); + all_clusters + .choose_multiple(&mut rng, n_clusters.min(all_clusters.len())) + .cloned() + .collect() + }, + RngKind::Faster => { + let mut rng = FasterRng::create(seed); + all_clusters + .choose_multiple(&mut rng, n_clusters.min(all_clusters.len())) + .cloned() + .collect() + }, + RngKind::Cryptosecure => { + let mut rng = CryptoRng::create(seed); + all_clusters + .choose_multiple(&mut rng, n_clusters.min(all_clusters.len())) + .cloned() + .collect() + }, + }; + + // Second pass: output records from selected clusters + let mut rdr2 = rconfig.reader()?; + for record in rdr2.byte_records() { + let record = record?; + let cluster = record + .get(cluster_column) + .ok_or_else(|| format!("Cluster column index {cluster_column} out of bounds"))? 
+ .to_vec(); + + if selected_clusters.contains(&cluster) { + wtr.write_byte_record(&record)?; + } + } + + Ok(()) +} From dd84939e2fd7a446e93e52b1cab36caa40898810 Mon Sep 17 00:00:00 2001 From: Joel Natividad <1980690+jqnatividad@users.noreply.github.com> Date: Sun, 9 Feb 2025 16:09:27 -0500 Subject: [PATCH 2/2] tests: `sample` add tests for all the new sampling methods --- tests/test_sample.rs | 470 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 470 insertions(+) diff --git a/tests/test_sample.rs b/tests/test_sample.rs index de8e3dbed..017dbabf8 100644 --- a/tests/test_sample.rs +++ b/tests/test_sample.rs @@ -517,6 +517,8 @@ fn sample_bernoulli_seed() { .arg("0.5") .arg("in.csv"); + wrk.assert_success(&mut cmd); + let got: Vec> = wrk.read_stdout(&mut cmd); let expected = vec![ svec!["R", "S"], @@ -615,3 +617,471 @@ fn sample_bernoulli_invalid_probability() { cmd.args(["--bernoulli"]).arg("-0.5").arg("in.csv"); wrk.assert_err(&mut cmd); } + +#[test] +fn sample_systematic() { + let wrk = Workdir::new("sample_systematic"); + wrk.create( + "in.csv", + vec![ + svec!["R", "S"], + svec!["1", "b"], + svec!["2", "a"], + svec!["3", "d"], + svec!["4", "c"], + svec!["5", "f"], + svec!["6", "e"], + svec!["7", "i"], + svec!["8", "h"], + svec!["9", "g"], + svec!["10", "j"], + ], + ); + + let mut cmd = wrk.command("sample"); + cmd.args(["--systematic", "first"]).arg("3").arg("in.csv"); + + wrk.assert_success(&mut cmd); + + let got: Vec> = wrk.read_stdout(&mut cmd); + let expected = vec![ + svec!["R", "S"], + svec!["1", "b"], + svec!["4", "c"], + svec!["7", "i"], + ]; + assert_eq!(got, expected); +} + +#[test] +fn sample_stratified() { + let wrk = Workdir::new("sample_stratified"); + wrk.create( + "in.csv", + vec![ + svec!["Group", "Value"], + svec!["A", "1"], + svec!["A", "2"], + svec!["A", "3"], + svec!["B", "4"], + svec!["B", "5"], + svec!["B", "6"], + svec!["C", "7"], + svec!["C", "8"], + svec!["C", "9"], + svec!["C", "10"], + svec!["C", "11"], + svec!["D", "12"], 
+ ], + ); + + let mut cmd = wrk.command("sample"); + cmd.args(["--stratified", "Group"]) + .args(["--seed", "42"]) + .arg("2") + .arg("in.csv"); + + wrk.assert_success(&mut cmd); + + let got: Vec> = wrk.read_stdout(&mut cmd); + let expected = vec![ + svec!["Group", "Value"], + svec!["A", "3"], + svec!["A", "2"], + svec!["B", "4"], + svec!["B", "6"], + svec!["C", "9"], + svec!["C", "8"], + svec!["D", "12"], + ]; + assert_eq!(got, expected); +} + +#[test] +fn sample_stratified_large_sample_size() { + let wrk = Workdir::new("sample_stratified_large_sample_size"); + wrk.create( + "in.csv", + vec![ + svec!["Group", "Value"], + svec!["A", "1"], + svec!["A", "2"], + svec!["A", "3"], + svec!["B", "4"], + svec!["B", "5"], + svec!["B", "6"], + svec!["C", "7"], + svec!["C", "8"], + svec!["C", "9"], + svec!["C", "10"], + svec!["C", "11"], + svec!["D", "12"], + ], + ); + + let mut cmd = wrk.command("sample"); + cmd.args(["--stratified", "Group"]) + .args(["--seed", "42"]) + .arg("100") + .arg("in.csv"); + + wrk.assert_success(&mut cmd); + + let got: Vec> = wrk.read_stdout(&mut cmd); + let expected = vec![ + svec!["Group", "Value"], + svec!["A", "1"], + svec!["A", "2"], + svec!["A", "3"], + svec!["B", "4"], + svec!["B", "5"], + svec!["B", "6"], + svec!["C", "7"], + svec!["C", "8"], + svec!["C", "9"], + svec!["C", "10"], + svec!["C", "11"], + svec!["D", "12"], + ]; + assert_eq!(got, expected); +} + +#[test] +fn sample_weighted() { + let wrk = Workdir::new("sample_weighted"); + wrk.create( + "in.csv", + vec![ + svec!["ID", "Weight"], + svec!["1", "10"], + svec!["2", "20"], + svec!["3", "30"], + svec!["4", "40"], + svec!["5", "50"], + svec!["6", "60"], + svec!["7", "70"], + svec!["8", "80"], + svec!["9", "90"], + svec!["10", "100"], + ], + ); + + let mut cmd = wrk.command("sample"); + cmd.args(["--weighted", "ID"]) + .args(["--seed", "42"]) + .arg("4") + .arg("in.csv"); + + let got: Vec> = wrk.read_stdout(&mut cmd); + let expected = vec![ + svec!["ID", "Weight"], + svec!["5", 
"50"], + svec!["6", "60"], + svec!["9", "90"], + svec!["10", "100"], + ]; + assert_eq!(got, expected); +} + +#[test] +fn sample_cluster() { + let wrk = Workdir::new("sample_cluster"); + wrk.create( + "in.csv", + vec![ + svec!["Household", "Person", "Age"], + svec!["H1", "P1", "25"], + svec!["H1", "P2", "30"], + svec!["H1", "P3", "35"], + svec!["H2", "P3", "45"], + svec!["H2", "P4", "50"], + svec!["H2", "P5", "55"], + svec!["H3", "P5", "35"], + svec!["H3", "P6", "40"], + svec!["H3", "P7", "45"], + svec!["H4", "P7", "28"], + svec!["H4", "P8", "32"], + svec!["H4", "P9", "36"], + svec!["H4", "P10", "40"], + svec!["H5", "P11", "44"], + svec!["H5", "P12", "48"], + ], + ); + + let mut cmd = wrk.command("sample"); + cmd.args(["--cluster", "Household"]) + .args(["--seed", "42"]) + .arg("2") + .arg("in.csv"); + + let got: Vec> = wrk.read_stdout(&mut cmd); + let expected = vec![ + svec!["Household", "Person", "Age"], + svec!["H1", "P1", "25"], + svec!["H1", "P2", "30"], + svec!["H1", "P3", "35"], + svec!["H3", "P5", "35"], + svec!["H3", "P6", "40"], + svec!["H3", "P7", "45"], + ]; + assert_eq!(got, expected); +} + +#[test] +fn sample_stratified_invalid_column() { + let wrk = Workdir::new("sample_stratified_invalid"); + wrk.create( + "in.csv", + vec![svec!["Group", "Value"], svec!["A", "1"], svec!["B", "2"]], + ); + + let mut cmd = wrk.command("sample"); + cmd.args(["--stratified", "999"]).arg("1").arg("in.csv"); + + wrk.assert_err(&mut cmd); +} + +#[test] +fn sample_weighted_negative_weights() { + let wrk = Workdir::new("sample_weighted_negative"); + wrk.create( + "in.csv", + vec![ + svec!["ID", "Weight"], + svec!["1", "-10"], + svec!["2", "20"], + svec!["3", "30"], + svec!["4", "40"], + svec!["5", "-50"], + ], + ); + + let mut cmd = wrk.command("sample"); + cmd.args(["--weighted", "1"]).arg("1").arg("in.csv"); + + wrk.assert_err(&mut cmd); +} + +#[test] +fn sample_stratified_empty_stratum() { + let wrk = Workdir::new("sample_stratified_empty"); + wrk.create( + "in.csv", + 
vec![ + svec!["Group", "Value"], + svec!["A", "1"], + svec!["", "2"], // empty stratum + svec!["A", "3"], + svec!["B", "4"], + svec!["", "5"], // another empty stratum + ], + ); + + let mut cmd = wrk.command("sample"); + cmd.args(["--stratified", "Group"]) + .args(["--seed", "42"]) + .arg("2") + .arg("in.csv"); + + wrk.assert_success(&mut cmd); + + let got: Vec<Vec<String>> = wrk.read_stdout(&mut cmd); + let expected = vec![ + svec!["Group", "Value"], + svec!["", "2"], + svec!["", "5"], + svec!["A", "1"], + svec!["A", "3"], + svec!["B", "4"], + ]; + assert_eq!(got, expected); +} + +#[test] +fn sample_weighted_zero_weights() { + let wrk = Workdir::new("sample_weighted_zero"); + wrk.create( + "in.csv", + vec![ + svec!["ID", "Weight"], + svec!["1", "0"], + svec!["2", "0"], + svec!["3", "30"], + svec!["4", "0"], + svec!["5", "50"], + ], + ); + + let mut cmd = wrk.command("sample"); + cmd.args(["--weighted", "Weight"]) + .args(["--seed", "42"]) + .arg("2") + .arg("in.csv"); + + let got: Vec<Vec<String>> = wrk.read_stdout(&mut cmd); + let expected = vec![svec!["ID", "Weight"], svec!["3", "30"], svec!["5", "50"]]; + assert_eq!(got, expected); +} + +#[test] +fn sample_cluster_single_record() { + let wrk = Workdir::new("sample_cluster_single"); + wrk.create( + "in.csv", + vec![ + svec!["Cluster", "Value"], + svec!["A", "1"], // single record cluster + svec!["B", "2"], + svec!["B", "3"], + svec!["C", "4"], // single record cluster + ], + ); + + let mut cmd = wrk.command("sample"); + cmd.args(["--cluster", "Cluster"]) + .args(["--seed", "42"]) + .arg("2") + .arg("in.csv"); + + wrk.assert_success(&mut cmd); + + let got: Vec<Vec<String>> = wrk.read_stdout(&mut cmd); + let expected = vec![ + svec!["Cluster", "Value"], + svec!["A", "1"], + svec!["B", "2"], + svec!["B", "3"], + ]; + assert_eq!(got, expected); +} + +#[test] +fn sample_systematic_with_headers() { + let wrk = Workdir::new("sample_systematic_headers"); + wrk.create( + "in.csv", + vec![ + svec!["Header1", "Header2"], // should be preserved + svec!["1",
"a"], + svec!["2", "b"], + svec!["3", "c"], + svec!["4", "d"], + svec!["5", "e"], + svec!["6", "f"], + svec!["7", "g"], + svec!["8", "h"], + svec!["9", "i"], + ], + ); + + let mut cmd = wrk.command("sample"); + cmd.args(["--systematic", "first"]).arg("3").arg("in.csv"); + + let got: Vec<Vec<String>> = wrk.read_stdout(&mut cmd); + let expected = vec![ + svec!["Header1", "Header2"], + svec!["1", "a"], + svec!["4", "d"], + svec!["7", "g"], + ]; + assert_eq!(got, expected); +} + +#[test] +fn sample_systematic_with_headers_random_with_seed() { + let wrk = Workdir::new("sample_systematic_headers_random_with_seed"); + wrk.create( + "in.csv", + vec![ + svec!["Header1", "Header2"], // should be preserved + svec!["1", "a"], + svec!["2", "b"], + svec!["3", "c"], + svec!["4", "d"], + svec!["5", "e"], + svec!["6", "f"], + svec!["7", "g"], + svec!["8", "h"], + svec!["9", "i"], + svec!["10", "j"], + ], + ); + + let mut cmd = wrk.command("sample"); + cmd.args(["--systematic", "random", "--seed", "65"]) + .arg("4") + .arg("in.csv"); + + let got: Vec<Vec<String>> = wrk.read_stdout(&mut cmd); + let expected = vec![ + svec!["Header1", "Header2"], + svec!["5", "e"], + svec!["9", "i"], + ]; + assert_eq!(got, expected); +} + +#[test] +fn sample_systematic_no_headers() { + let wrk = Workdir::new("sample_systematic_no_headers"); + wrk.create( + "in.csv", + vec![ + svec!["1", "a"], + svec!["2", "b"], + svec!["3", "c"], + svec!["4", "d"], + svec!["5", "e"], + svec!["6", "f"], + svec!["7", "g"], + svec!["8", "h"], + svec!["9", "i"], + svec!["10", "j"], + ], + ); + + let mut cmd = wrk.command("sample"); + cmd.args(["--systematic", "first", "--no-headers"]) + .arg("3") + .arg("in.csv"); + + let got: Vec<Vec<String>> = wrk.read_stdout(&mut cmd); + let expected = vec![svec!["1", "a"], svec!["4", "d"], svec!["7", "g"]]; + assert_eq!(got, expected); +} + +#[test] +fn sample_multiple_methods_error() { + let wrk = Workdir::new("sample_multiple_methods"); + wrk.create( + "in.csv", + vec![svec!["ID", "Value"], svec!["1", "a"], svec!["2",
"b"]], + ); + + // Test combining bernoulli with systematic + let mut cmd = wrk.command("sample"); + cmd.args(["--bernoulli", "--systematic", "first"]) + .arg("0.5") + .arg("in.csv"); + wrk.assert_err(&mut cmd); + + // Test combining weighted with stratified + let mut cmd = wrk.command("sample"); + cmd.args(["--weighted", "ID", "--stratified", "ID"]) + .arg("1") + .arg("in.csv"); + wrk.assert_err(&mut cmd); +} + +#[test] +fn sample_invalid_rng() { + let wrk = Workdir::new("sample_invalid_rng"); + wrk.create( + "in.csv", + vec![svec!["ID", "Value"], svec!["1", "a"], svec!["2", "b"]], + ); + + let mut cmd = wrk.command("sample"); + cmd.args(["--rng", "invalid_rng"]).arg("1").arg("in.csv"); + wrk.assert_err(&mut cmd); +}