From 6c8107a0f640b6723823792f4cb7395dee40a620 Mon Sep 17 00:00:00 2001
From: Adam Binford <adamq43@gmail.com>
Date: Sun, 14 Jan 2024 11:50:37 -0500
Subject: [PATCH] Rework block writer error handling to simplify

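Make BlockWriter::close consume the writer, and have FileWriter take and
close the current block writer before calling add_block or complete. In
ReplicatedBlockWriter, replace the oneshot status channel with JoinHandles
for the ack listener and packet sender tasks: check_error reports an error
if either task has finished early, and close awaits both handles so all
packets are sent and the final ack is received.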
---
 crates/hdfs-native/src/file.rs              |  41 +++---
 crates/hdfs-native/src/hdfs/block_writer.rs | 134 ++++++++++----------
 crates/hdfs-native/src/hdfs/connection.rs   |   1 +
 3 files changed, 84 insertions(+), 92 deletions(-)

diff --git a/crates/hdfs-native/src/file.rs b/crates/hdfs-native/src/file.rs
index baf671e..7df9071 100644
--- a/crates/hdfs-native/src/file.rs
+++ b/crates/hdfs-native/src/file.rs
@@ -188,12 +188,17 @@ impl FileWriter {
             }
         } else {
             // Not appending to an existing block, just create a new one
+            // If there's an existing block writer, close it first
+            let extended_block = if let Some(block_writer) = self.block_writer.take() {
+                let extended_block = block_writer.get_extended_block();
+                block_writer.close().await?;
+                Some(extended_block)
+            } else {
+                None
+            };
+
             self.protocol
-                .add_block(
-                    &self.src,
-                    self.block_writer.as_ref().map(|b| b.get_extended_block()),
-                    self.status.file_id,
-                )
+                .add_block(&self.src, extended_block, self.status.file_id)
                 .await?
                 .block
         };
@@ -216,16 +221,8 @@ impl FileWriter {
     }
 
     async fn get_block_writer(&mut self) -> Result<&mut BlockWriter> {
-        // If the current writer is full, close it
-        if let Some(block_writer) = self.block_writer.as_mut() {
-            if block_writer.is_full() {
-                block_writer.close().await?;
-                self.create_block_writer().await?;
-            }
-        }
-
-        // If we haven't created a writer yet, create one
-        if self.block_writer.is_none() {
+        // If the current writer is full, or no writer has been created yet, create a new one
+        if self.block_writer.as_ref().is_some_and(|b| b.is_full()) || self.block_writer.is_none() {
             self.create_block_writer().await?;
         }
 
@@ -248,20 +245,20 @@ impl FileWriter {
 
     pub async fn close(&mut self) -> Result<()> {
         if !self.closed {
-            if let Some(block_writer) = self.block_writer.as_mut() {
+            let extended_block = if let Some(block_writer) = self.block_writer.take() {
+                let extended_block = block_writer.get_extended_block();
                 block_writer.close().await?;
-            }
+                Some(extended_block)
+            } else {
+                None
+            };
 
             let mut retry_delay = COMPLETE_RETRY_DELAY_MS;
             let mut retries = 0;
             while retries < COMPLETE_RETRIES {
                 let successful = self
                     .protocol
-                    .complete(
-                        &self.src,
-                        self.block_writer.as_ref().map(|b| b.get_extended_block()),
-                        self.status.file_id,
-                    )
+                    .complete(&self.src, extended_block.clone(), self.status.file_id)
                     .await?
                     .result;
 
diff --git a/crates/hdfs-native/src/hdfs/block_writer.rs b/crates/hdfs-native/src/hdfs/block_writer.rs
index 18e779c..7c030a7 100644
--- a/crates/hdfs-native/src/hdfs/block_writer.rs
+++ b/crates/hdfs-native/src/hdfs/block_writer.rs
@@ -3,7 +3,7 @@ use std::time::Duration;
 use bytes::{BufMut, Bytes, BytesMut};
 use futures::future::join_all;
 use log::debug;
-use tokio::sync::{mpsc, oneshot};
+use tokio::{sync::mpsc, task::JoinHandle};
 
 use crate::{
     ec::{gf256::Coder, EcSchema},
@@ -67,7 +67,7 @@ impl BlockWriter {
         }
     }
 
-    pub(crate) async fn close(&mut self) -> Result<()> {
+    pub(crate) async fn close(self) -> Result<()> {
         match self {
             Self::Replicated(writer) => writer.close().await,
             Self::Striped(writer) => writer.close().await,
@@ -85,7 +85,10 @@ pub(crate) struct ReplicatedBlockWriter {
 
     // Tracks the state of acknowledgements. Set to an Err if any error occurs doing receiving
     // acknowledgements. Set to Ok(()) when the last acknowledgement is received.
-    status: Option<oneshot::Receiver<Result<()>>>,
+    ack_listener_handle: JoinHandle<Result<()>>,
+    // Tracks the state of the packet sender. Set to Err if any error occurs while writing packets.
+    packet_sender_handle: JoinHandle<Result<DatanodeWriter>>,
+
     ack_queue: mpsc::Sender<(i64, bool)>,
     packet_sender: mpsc::Sender<Packet>,
 }
@@ -140,13 +143,11 @@ impl ReplicatedBlockWriter {
 
         // Channel for tracking packets that need to be acked
         let (ack_queue_sender, ack_queue_receiever) = mpsc::channel::<(i64, bool)>(100);
-        // Channel for tracking errors that occur listening for acks or successful ack of the last packet
-        let (status_sender, status_receiver) = oneshot::channel::<Result<()>>();
         let (packet_sender, packet_receiver) = mpsc::channel::<Packet>(100);
 
-        Self::listen_for_acks(reader, ack_queue_receiever, status_sender);
-        Self::start_packet_sender(writer, packet_receiver);
-        Self::start_heartbeat(packet_sender.clone());
+        let ack_listener_handle = Self::listen_for_acks(reader, ack_queue_receiever);
+        let packet_sender_handle = Self::start_packet_sender(writer, packet_receiver);
+        Self::start_heartbeat_sender(packet_sender.clone());
 
         let bytes_per_checksum = server_defaults.bytes_per_checksum;
         let write_packet_size = server_defaults.write_packet_size;
@@ -173,7 +174,10 @@ impl ReplicatedBlockWriter {
             server_defaults,
             next_seqno: 1,
             current_packet,
-            status: Some(status_receiver),
+
+            ack_listener_handle,
+            packet_sender_handle,
+
             ack_queue: ack_queue_sender,
             packet_sender,
         };
@@ -218,16 +222,17 @@ impl ReplicatedBlockWriter {
     }
 
     fn check_error(&mut self) -> Result<()> {
-        if let Some(status) = self.status.as_mut() {
-            match status.try_recv() {
-                Ok(result) => result?,
-                Err(oneshot::error::TryRecvError::Empty) => (),
-                Err(oneshot::error::TryRecvError::Closed) => {
-                    return Err(HdfsError::DataTransferError(
-                        "Status channel closed prematurely".to_string(),
-                    ))
-                }
-            }
+        // If either task is finished, something went wrong
+        if self.ack_listener_handle.is_finished() {
+            return Err(HdfsError::DataTransferError(
+                "Ack listener finished prematurely".to_string(),
+            ));
+        }
+
+        if self.packet_sender_handle.is_finished() {
+            return Err(HdfsError::DataTransferError(
+                "Packet sender finished prematurely".to_string(),
+            ));
         }
 
         Ok(())
@@ -267,7 +272,7 @@ impl ReplicatedBlockWriter {
     }
 
     /// Send a packet with any remaining data and then send a last packet
-    async fn close(&mut self) -> Result<()> {
+    async fn close(mut self) -> Result<()> {
         self.check_error()?;
 
         // Send a packet with any remaining data
@@ -279,45 +284,35 @@ impl ReplicatedBlockWriter {
         self.current_packet.set_last_packet();
         self.send_current_packet().await?;
 
-        // Wait for the channel to close, meaning all acks have been received or an error occured
-        if let Some(status) = self.status.take() {
-            let result = status.await.map_err(|_| {
-                HdfsError::DataTransferError(
-                    "Status channel closed while waiting for final ack".to_string(),
-                )
-            })?;
-            result?;
-        } else {
-            return Err(HdfsError::DataTransferError(
-                "Block already closed".to_string(),
-            ));
-        }
+        // Wait for all packets to be sent
+        self.packet_sender_handle.await.map_err(|_| {
+            HdfsError::DataTransferError(
+                "Packet sender task failed while waiting for packets to send".to_string(),
+            )
+        })??;
 
-        Ok(())
+        // Wait for the ack listener to finish, meaning all acks have been received or an error occurred
+        self.ack_listener_handle.await.map_err(|_| {
+            HdfsError::DataTransferError(
+                "Ack listener task failed while waiting for final ack".to_string(),
+            )
+        })?
     }
 
     fn listen_for_acks(
         mut reader: DatanodeReader,
         mut ack_queue: mpsc::Receiver<(i64, bool)>,
-        status: oneshot::Sender<Result<()>>,
-    ) {
+    ) -> JoinHandle<Result<()>> {
         tokio::spawn(async move {
             loop {
-                let next_ack = match reader.read_ack().await {
-                    Ok(next_ack) => next_ack,
-                    Err(e) => {
-                        let _ = status.send(Err(e));
-                        break;
-                    }
-                };
+                let next_ack = reader.read_ack().await?;
 
                 for reply in next_ack.reply.iter() {
                     if *reply != hdfs::Status::Success as i32 {
-                        let _ = status.send(Err(HdfsError::DataTransferError(format!(
+                        return Err(HdfsError::DataTransferError(format!(
                             "Received non-success status in datanode ack: {:?}",
                             hdfs::Status::from_i32(*reply)
-                        ))));
-                        return;
+                        )));
                     }
                 }
 
@@ -325,57 +320,56 @@ impl ReplicatedBlockWriter {
                     continue;
                 }
                 if next_ack.seqno == UNKNOWN_SEQNO {
-                    let _ = status.send(Err(HdfsError::DataTransferError(
+                    return Err(HdfsError::DataTransferError(
                         "Received unknown seqno for successful ack".to_string(),
-                    )));
-                    break;
+                    ));
                 }
 
                 if let Some((seqno, last_packet)) = ack_queue.recv().await {
                     if next_ack.seqno != seqno {
-                        let _ = status.send(Err(HdfsError::DataTransferError(
+                        return Err(HdfsError::DataTransferError(
                             "Received acknowledgement does not match expected sequence number"
                                 .to_string(),
-                        )));
-                        break;
+                        ));
                     }
 
                     if last_packet {
-                        let _ = status.send(Ok(()));
-                        break;
+                        return Ok(());
                     }
                 } else {
-                    let _ = status.send(Err(HdfsError::DataTransferError(
+                    return Err(HdfsError::DataTransferError(
                         "Channel closed while getting next seqno to acknowledge".to_string(),
-                    )));
-                    break;
+                    ));
                 }
             }
-        });
+        })
     }
 
     fn start_packet_sender(
         mut writer: DatanodeWriter,
         mut packet_receiver: mpsc::Receiver<Packet>,
-    ) {
+    ) -> JoinHandle<Result<DatanodeWriter>> {
         tokio::spawn(async move {
             while let Some(mut packet) = packet_receiver.recv().await {
-                match writer.write_packet(&mut packet).await {
-                    Ok(_) if packet.header.last_packet_in_block => break,
-                    Ok(_) => (),
-                    Err(_) => panic!(),
+                writer.write_packet(&mut packet).await?;
+
+                if packet.header.last_packet_in_block {
+                    break;
                 }
             }
-            writer
-        });
+            Ok(writer)
+        })
     }
 
-    fn start_heartbeat(packet_sender: mpsc::Sender<Packet>) {
+    fn start_heartbeat_sender(packet_sender: mpsc::Sender<Packet>) {
         tokio::spawn(async move {
             loop {
                 tokio::time::sleep(Duration::from_secs(HEARTBEAT_INTERVAL_SECONDS)).await;
                 let heartbeat_packet = Packet::empty(0, HEART_BEAT_SEQNO, 0, 0);
-                let _ = packet_sender.send(heartbeat_packet).await;
+                // If this fails, sending any more data packets will generate an error as well
+                if packet_sender.send(heartbeat_packet).await.is_err() {
+                    break;
+                }
             }
         });
     }
@@ -560,15 +554,15 @@ impl StripedBlockWriter {
         Ok(())
     }
 
-    async fn close(&mut self) -> Result<()> {
+    async fn close(mut self) -> Result<()> {
         if !self.cell_buffer.is_empty() {
             self.write_cells().await?;
         }
 
         let close_futures = self
             .block_writers
-            .iter_mut()
-            .filter_map(|writer| writer.as_mut())
+            .into_iter()
+            .filter_map(|mut writer| writer.take())
             .map(|writer| async move { writer.close().await });
 
         for close_result in join_all(close_futures).await {
diff --git a/crates/hdfs-native/src/hdfs/connection.rs b/crates/hdfs-native/src/hdfs/connection.rs
index e0c5e07..3879fb6 100644
--- a/crates/hdfs-native/src/hdfs/connection.rs
+++ b/crates/hdfs-native/src/hdfs/connection.rs
@@ -600,6 +600,7 @@ impl DatanodeConnection {
         (reader, writer)
     }
 
+    // For future use when we cache datanode connections
     #[allow(dead_code)]
     pub(crate) fn reunite(reader: DatanodeReader, writer: DatanodeWriter) -> Self {
         assert_eq!(reader.client_name, writer.client_name);