Skip to content

Commit

Permalink
Merge pull request #34 from glcraft/fixes
Browse files Browse the repository at this point in the history
Update 0.8.1
  • Loading branch information
glcraft authored Nov 16, 2023
2 parents 755f929 + f76caab commit 69627f9
Show file tree
Hide file tree
Showing 8 changed files with 268 additions and 251 deletions.
File renamed without changes.
2 changes: 1 addition & 1 deletion src/generators/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
pub mod openai;
pub mod debug;
pub mod from_file;

use tokio_stream::Stream;
use thiserror::Error;
Expand Down
157 changes: 0 additions & 157 deletions src/generators/openai/flatten_stream.rs

This file was deleted.

103 changes: 11 additions & 92 deletions src/generators/openai/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
pub mod config;
pub mod credentials;
mod flatten_stream;
use bytes::Bytes;
use flatten_stream::FlattenTrait;

use serde::{Serialize, Deserialize};
use tokio_stream::StreamExt;
use crate::args;
use crate::{
args,
utils::{
SplitBytesFactory,
FlattenTrait
}
};
use self::config::Prompt;

use super::{ResultRun, Error};
Expand Down Expand Up @@ -199,89 +202,6 @@ impl ChatResponse {
}
}

struct SplitBytesFactory<Sep>
where
Sep: AsRef<[u8]>
{
separator: Sep,
rest: Vec<u8>,
}

impl<Sep> SplitBytesFactory<Sep>
where
Sep: AsRef<[u8]> + Clone
{
fn new(separator: Sep) -> Self {
Self {
separator,
rest: Vec::new(),
}
}
fn new_iter(&mut self, bytes: Bytes) -> SplitBytes<Sep> {
let sep_len = self.separator.as_ref().len();
let pos_last_separator = bytes.len() - (sep_len + bytes
.windows(self.separator.as_ref().len())
.rev()
.position(|b| b == self.separator.as_ref())
.unwrap_or(bytes.len()));

let mut current = Vec::new();
std::mem::swap(&mut current, &mut self.rest);
current.append(&mut bytes.slice(..pos_last_separator).to_vec());
self.rest = bytes.slice((pos_last_separator + sep_len)..).to_vec();
SplitBytes::new(Bytes::from(current), self.separator.clone())
}
}

struct SplitBytes<Sep>
where
Sep: AsRef<[u8]>
{
bytes: Bytes,
separator: Sep,
index: Option<usize>,
}

impl<Sep> SplitBytes<Sep>
where
Sep: AsRef<[u8]>
{
fn new(bytes: Bytes, separator: Sep) -> Self {
Self {
bytes,
separator,
index: Some(0),
}
}
}

impl<Sep> Iterator for SplitBytes<Sep>
where
Sep: AsRef<[u8]>
{
type Item = Bytes;
fn next(&mut self) -> Option<Self::Item> {
let separator = self.separator.as_ref();
let index = self.index?;
let bytes = self.bytes.slice(index..);
let found = bytes
.windows(separator.len())
.find(|b| b == &separator);
let slice_bytes = if let Some(found) = found {
let end_selection = found.as_ptr() as usize - bytes.as_ptr() as usize;
self.index = self.index.map(|i| i + end_selection + found.len());
bytes.slice(..end_selection)
} else {
self.index = None;
bytes
};
match slice_bytes.is_empty() {
false => Some(slice_bytes),
true => None,
}
}
}

pub async fn run(creds: credentials::Credentials, config: crate::config::Config, args: args::ProcessedArgs) -> ResultRun {
let openai_api_key = creds.api_key;

Expand Down Expand Up @@ -331,17 +251,16 @@ pub async fn run(creds: credentials::Credentials, config: crate::config::Config,
.expect("Failed to open log file")
)
});
LOG.lock().and_then(|mut log|{
if let Ok(mut log) = LOG.lock() {
log.write_all(&input)
.and_then(|_| log.write_all(b"\n---\n"))
.expect("Debug: Failed to write to log file");
Ok(())
});
.expect("Debug: Failed to write to log file");
}
}

Ok(split_bytes_factory.new_iter(input))
})
.flatten_stream()
.flatten_result_iter()
.map(|v| {
let v = v?;
let chat_resp = ChatResponse::from_bytes(v);
Expand Down
3 changes: 2 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod arguments;
mod utils;
mod config;
mod credentials;
mod filesystem;
Expand Down Expand Up @@ -69,7 +70,7 @@ async fn main() -> Result<(), String> {

let mut stream = match engine {
"openai" => generators::openai::run(creds.openai, config, args).await,
"from-file" => generators::debug::run(config, args).await,
"from-file" => generators::from_file::run(config, args).await,
_ => panic!("Unknown engine: {}", engine),
}
.map_err(|e| format!("Failed to request OpenAI API: {}", e))?;
Expand Down
Loading

0 comments on commit 69627f9

Please sign in to comment.