Skip to content

Commit

Permalink
Streamline the glm4 example.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Dec 31, 2024
1 parent 460616f commit f2f73ae
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 147 deletions.
6 changes: 5 additions & 1 deletion candle-examples/examples/flux/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,11 @@ fn run(args: Args) -> Result<()> {
};
println!("img\n{img}");
let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(candle::DType::U8)?;
candle_examples::save_image(&img.i(0)?, "out.jpg")?;
let filename = match args.seed {
None => "out.jpg".to_string(),
Some(s) => format!("out-{s}.jpg"),
};
candle_examples::save_image(&img.i(0)?, filename)?;
Ok(())
}

Expand Down
39 changes: 8 additions & 31 deletions candle-examples/examples/glm4/README.org
Original file line number Diff line number Diff line change
Expand Up @@ -7,48 +7,25 @@ GLM-4-9B is the open-source version of the latest generation of pre-trained mode
** Running with ~cuda~

#+begin_src shell
cargo run --example glm4 --release --features cuda
cargo run --example glm4 --release --features cuda -- --prompt "Hello world"
#+end_src

** Running with ~cpu~
#+begin_src shell
cargo run --example glm4 --release -- --cpu
cargo run --example glm4 --release -- --cpu--prompt "Hello world"
#+end_src

** Output Example
#+begin_src shell
cargo run --example glm4 --release --features cuda -- --sample-len 500 --cache .
Finished release [optimized] target(s) in 0.24s
Running `/root/candle/target/release/examples/glm4 --sample-len 500 --cache .`
cargo run --features cuda -r --example glm4 -- --prompt "Hello "

avx: true, neon: false, simd128: false, f16c: true
temp: 0.60 repeat-penalty: 1.20 repeat-last-n: 64
cache path .
retrieved the files in 6.88963ms
loaded the model in 6.113752297s
retrieved the files in 6.454375ms
loaded the model in 3.652383779s
starting the inference loop
[欢迎使用GLM-4,请输入prompt]
请你告诉我什么是FFT
266 tokens generated (34.50 token/s)
Result:
。Fast Fourier Transform (FFT) 是一种快速计算离散傅里叶变换(DFT)的方法,它广泛应用于信号处理、图像处理和数据分析等领域。

具体来说,FFT是一种将时域数据转换为频域数据的算法。在数字信号处理中,我们通常需要知道信号的频率成分,这就需要进行傅立叶变换。传统的傅立叶变换的计算复杂度较高,而 FFT 则大大提高了计算效率,使得大规模的 DFT 换成为可能。

以下是使用 Python 中的 numpy 进行 FFT 的简单示例:

```python
import numpy as np

# 创建一个时域信号
t = np.linspace(0, 1, num=100)
f = np.sin(2*np.pi*5*t) + 3*np.cos(2*np.pi*10*t)

# 对该信号做FFT变换,并计算其幅值谱
fft_result = np.fft.fftshift(np.abs(np.fft.fft(f)))

```

在这个例子中,我们首先创建了一个时域信号 f。然后我们对这个信号进行了 FFT 换,得到了一个频域结果 fft_result。
Hello 2018, hello new year! I’m so excited to be back and sharing with you all my favorite things from the past month. This is a monthly series where I share what’s been inspiring me lately in hopes that it will inspire you too!
...
#+end_src

This example will read prompt from stdin
Expand Down
201 changes: 86 additions & 115 deletions candle-examples/examples/glm4/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,120 +12,97 @@ struct TextGeneration {
device: Device,
tokenizer: Tokenizer,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
args: Args,
dtype: DType,
}

impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
device: &Device,
dtype: DType,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
fn new(model: Model, tokenizer: Tokenizer, args: Args, device: &Device, dtype: DType) -> Self {
let logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
Self {
model,
tokenizer,
logits_processor,
repeat_penalty,
repeat_last_n,
verbose_prompt,
args,
device: device.clone(),
dtype,
}
}

fn run(&mut self, sample_len: usize) -> anyhow::Result<()> {
use std::io::BufRead;
use std::io::BufReader;
fn run(&mut self) -> anyhow::Result<()> {
use std::io::Write;
let args = &self.args;
println!("starting the inference loop");
println!("[欢迎使用GLM-4,请输入prompt]");
let stdin = std::io::stdin();
let reader = BufReader::new(stdin);
for line in reader.lines() {
let line = line.expect("Failed to read line");

let tokens = self.tokenizer.encode(line, true).expect("tokens error");
if tokens.is_empty() {
panic!("Empty prompts are not supported in the chatglm model.")
}
if self.verbose_prompt {
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
println!("{id:7} -> '{token}'");
}

let tokens = self
.tokenizer
.encode(args.prompt.to_string(), true)
.expect("tokens error");
if tokens.is_empty() {
panic!("Empty prompts are not supported in the chatglm model.")
}
if args.verbose {
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
println!("{id:7} -> '{token}'");
}
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
Some(token) => *token,
None => panic!("cannot find the endoftext token"),
} else {
print!("{}", &args.prompt);
std::io::stdout().flush()?;
}
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
Some(token) => *token,
None => panic!("cannot find the endoftext token"),
};
let mut tokens = tokens.get_ids().to_vec();
let mut generated_tokens = 0usize;

std::io::stdout().flush().expect("output flush error");
let start_gen = std::time::Instant::now();

for index in 0..args.sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input)?;
let logits = logits.squeeze(0)?.to_dtype(self.dtype)?;
let logits = if args.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(args.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
args.repeat_penalty,
&tokens[start_at..],
)?
};
let mut tokens = tokens.get_ids().to_vec();
let mut generated_tokens = 0usize;

std::io::stdout().flush().expect("output flush error");
let start_gen = std::time::Instant::now();

let mut count = 0;
let mut result = vec![];
for index in 0..sample_len {
count += 1;
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input)?;
let logits = logits.squeeze(0)?.to_dtype(self.dtype)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};

let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
let token = self
.tokenizer
.decode(&[next_token], true)
.expect("Token error");
if self.verbose_prompt {
println!(
"[Count: {}] [Raw Token: {}] [Decode Token: {}]",
count, next_token, token
);
}
result.push(token);
std::io::stdout().flush()?;

let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
let dt = start_gen.elapsed();
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
println!("Result:");
for tokens in result {
print!("{tokens}");
let token = self
.tokenizer
.decode(&[next_token], true)
.expect("token decode error");
if args.verbose {
println!(
"[Count: {}] [Raw Token: {}] [Decode Token: {}]",
generated_tokens, next_token, token
);
} else {
print!("{token}");
std::io::stdout().flush()?;
}
self.model.reset_kv_cache(); // clean the cache
}
let dt = start_gen.elapsed();
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
Expand All @@ -141,7 +118,11 @@ struct Args {

/// Display the token for the specified prompt.
#[arg(long)]
verbose_prompt: bool,
prompt: String,

/// Display the tokens for the specified prompt and outputs.
#[arg(long)]
verbose: bool,

/// The temperature used to generate samples.
#[arg(long)]
Expand Down Expand Up @@ -197,28 +178,29 @@ fn main() -> anyhow::Result<()> {
);

let start = std::time::Instant::now();
println!("cache path {}", args.cache_path);
let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into()))
.build()
.map_err(anyhow::Error::msg)?;
let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(
args.cache_path.to_string().into(),
))
.build()
.map_err(anyhow::Error::msg)?;

let model_id = match args.model_id {
let model_id = match args.model_id.as_ref() {
Some(model_id) => model_id.to_string(),
None => "THUDM/glm-4-9b".to_string(),
};
let revision = match args.revision {
let revision = match args.revision.as_ref() {
Some(rev) => rev.to_string(),
None => "main".to_string(),
};
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let tokenizer_filename = match args.tokenizer {
let tokenizer_filename = match args.tokenizer.as_ref() {
Some(file) => std::path::PathBuf::from(file),
None => api
.model("THUDM/codegeex4-all-9b".to_string())
.get("tokenizer.json")
.map_err(anyhow::Error::msg)?,
};
let filenames = match args.weight_file {
let filenames = match args.weight_file.as_ref() {
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
Expand All @@ -238,18 +220,7 @@ fn main() -> anyhow::Result<()> {

println!("loaded the model in {:?}", start.elapsed());

let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
args.verbose_prompt,
&device,
dtype,
);
pipeline.run(args.sample_len)?;
let mut pipeline = TextGeneration::new(model, tokenizer, args, &device, dtype);
pipeline.run()?;
Ok(())
}

0 comments on commit f2f73ae

Please sign in to comment.