Static library 2.4 #12

Merged: 44 commits, Oct 22, 2024
- d7ed2c0: Add support for `Dict[str, Any]` (VivekPanyam, Sep 24, 2023)
- ce27216: Fix minor typos in README.md (bionicles, Dec 5, 2023)
- a0b8580: Merge pull request #826 from bionicles/patch-1 (LaurentMazare, Dec 7, 2023)
- 2dcdac6: Document file formats in VarStore::save and load methods (necrashter, Dec 19, 2023)
- e9ec104: add tch-rs fork implement result indexing (CallMeMSL, Jan 2, 2024)
- f63558d: remove unwraps (CallMeMSL, Jan 2, 2024)
- bf3afe0: fmt (CallMeMSL, Jan 2, 2024)
- 9ffb2ca: Merge pull request #829 from necrashter/save-load-docs (LaurentMazare, Jan 6, 2024)
- a5f6ea8: Merge pull request #801 from VivekPanyam/fix_dict_bug (LaurentMazare, Jan 6, 2024)
- 12b97a6: Fix rustfmt + minor tweaks. (LaurentMazare, Jan 6, 2024)
- 653520d: Merge remote-tracking branch 'origin/main' into index_result (LaurentMazare, Jan 6, 2024)
- 7f41f93: Clippy fix. (LaurentMazare, Jan 6, 2024)
- ca92737: Merge pull request #835 from LaurentMazare/index_result (LaurentMazare, Jan 6, 2024)
- afd5cc6: Add the declarations file for PyTorch 2.2.0. (LaurentMazare, Jan 18, 2024)
- 2a70ecf: Merge pull request #839 from LaurentMazare/decl-2.2.0 (LaurentMazare, Jan 18, 2024)
- 7b9ef7e: Update for torch 2.2.0 (LaurentMazare, Jan 18, 2024)
- cc28ce7: Update more crates. (LaurentMazare, Jan 18, 2024)
- 852b447: Fix the build names. (LaurentMazare, Jan 30, 2024)
- 2626465: Merge pull request #840 from LaurentMazare/torch-2.2.0 (LaurentMazare, Jan 30, 2024)
- 4fc5709: Fixes for clippy 1.76. (LaurentMazare, Feb 8, 2024)
- 4cdee63: Merge pull request #847 from LaurentMazare/clippy-1.76 (LaurentMazare, Feb 8, 2024)
- d068b18: Temporary Fix for GLOG (Jark5455, Mar 4, 2024)
- 27c184f: Add the declarations file for PyTorch 2.3. (LaurentMazare, Apr 14, 2024)
- 420e41d: Merge pull request #863 from LaurentMazare/decl-2.3 (LaurentMazare, Apr 14, 2024)
- a4c1432: Update for PyTorch 2.3. (LaurentMazare, Apr 14, 2024)
- dd74e82: Update more packages. (LaurentMazare, Apr 14, 2024)
- 1bf6d7e: Update for pyo3 0.21. (Apr 18, 2024)
- 3e86831: Merge pull request #868 from LaurentMazare/pyo3-0.21 (LaurentMazare, Apr 18, 2024)
- 269ff36: Merge branch 'main' into torch-2.3.0 (LaurentMazare, Apr 24, 2024)
- a353eac: Merge pull request #864 from LaurentMazare/torch-2.3.0 (LaurentMazare, Apr 24, 2024)
- 37d90c3: Merge pull request #852 from Jark5455/main (LaurentMazare, Apr 27, 2024)
- f4a4eef: Clippy fixes for 1.78. (LaurentMazare, May 3, 2024)
- 7e3a4bf: Another clippy fix. (LaurentMazare, May 3, 2024)
- a90854b: Merge pull request #869 from LaurentMazare/clippy-fixes (LaurentMazare, May 3, 2024)
- d6db322: Add the PyTorch 2.4 declarations file. (LaurentMazare, Jul 15, 2024)
- e26845d: Bump the versions. (LaurentMazare, Jul 15, 2024)
- 16c4bb6: Update the bindings. (LaurentMazare, Jul 15, 2024)
- c326289: Couple fixes. (LaurentMazare, Jul 15, 2024)
- 3df9ece: Get the crate to compile and the tests to pass. (LaurentMazare, Jul 15, 2024)
- 0d6c6b2: Clippy fixes. (LaurentMazare, Jul 15, 2024)
- a4e9362: Merge pull request #878 from LaurentMazare/2.4 (LaurentMazare, Jul 24, 2024)
- f5f0630: Merge remote-tracking branch 'upstream/main' into update_v2.4 (juhofuriosa, Sep 4, 2024)
- 7d6ac77: Fix clippy error (juhofuriosa, Sep 5, 2024)
- 10a1ed7: Adjust libtorch static library list and option (juhofuriosa, Sep 4, 2024)
12 changes: 12 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased
### Changed

## v0.17.0
### Changed
- PyTorch v2.4 support

## v0.16.0
### Changed
- PyTorch v2.3 support

## v0.15.0
### Changed
- PyTorch v2.2 support

## v0.14.0
### Changed
- PyTorch v2.1 support
4 changes: 2 additions & 2 deletions Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "tch"
version = "0.14.0"
version = "0.17.0"
authors = ["Laurent Mazare <[email protected]>"]
edition = "2021"
build = "build.rs"
@@ -23,7 +23,7 @@ ndarray = "0.15"
num-traits = "0.2.15"
rand = "0.8"
thiserror = "1"
torch-sys = { version = "0.14.0", path = "torch-sys" }
torch-sys = { version = "0.17.0", path = "torch-sys" }
zip = { version = "0.6", default-features = false, features = [
# Any other features, in particular `deflate-<backend>` features
# may be added by other crates and override the default backend.
8 changes: 4 additions & 4 deletions README.md
@@ -18,7 +18,7 @@ The code generation part for the C api on top of libtorch comes from

## Getting Started

This crate requires the C++ PyTorch library (libtorch) in version *v2.1.0* to be available on
This crate requires the C++ PyTorch library (libtorch) in version *v2.4.0* to be available on
your system. You can either:

- Use the system-wide libtorch installation (default).
@@ -54,9 +54,9 @@ export LIBTORCH=/path/to/libtorch
The header files location can also be specified separately from the shared library via
the following:
```bash
# LIBTORCH_INCLUDE must contains `include` directory.
# LIBTORCH_INCLUDE must contain `include` directory.
export LIBTORCH_INCLUDE=/path/to/libtorch/
# LIBTORCH_LIB must contains `lib` directory.
# LIBTORCH_LIB must contain `lib` directory.
export LIBTORCH_LIB=/path/to/libtorch/
```
- For Windows users, assuming that `X:\path\to\libtorch` is the unzipped libtorch directory.
@@ -85,7 +85,7 @@ seem to include `libtorch.a` by default so this would have to be compiled
manually, e.g. via the following:

```bash
git clone -b v2.1.0 --recurse-submodule https://github.com/pytorch/pytorch.git pytorch-static --depth 1
git clone -b v2.4.0 --recurse-submodule https://github.com/pytorch/pytorch.git pytorch-static --depth 1
cd pytorch-static
USE_CUDA=OFF BUILD_SHARED_LIBS=OFF python setup.py build
# export LIBTORCH to point at the build directory in pytorch-static.
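The README changes above describe an environment-variable precedence: `LIBTORCH_INCLUDE` and `LIBTORCH_LIB` override the locations derived from `LIBTORCH`. A minimal sketch of that resolution logic (a hypothetical helper, not the crate's actual `build.rs`):

```rust
use std::path::PathBuf;

// Hypothetical sketch of the precedence the README describes:
// LIBTORCH_INCLUDE / LIBTORCH_LIB take priority, LIBTORCH is the fallback.
// Returns None when neither a specific root nor the fallback is set.
fn libtorch_paths(
    include: Option<&str>,
    lib: Option<&str>,
    libtorch: Option<&str>,
) -> Option<(PathBuf, PathBuf)> {
    let inc_root = include.or(libtorch)?;
    let lib_root = lib.or(libtorch)?;
    Some((
        PathBuf::from(inc_root).join("include"),
        PathBuf::from(lib_root).join("lib"),
    ))
}

fn main() {
    // With only LIBTORCH set, both locations derive from it.
    let (inc, lib) = libtorch_paths(None, None, Some("/path/to/libtorch")).unwrap();
    assert_eq!(inc, PathBuf::from("/path/to/libtorch/include"));
    assert_eq!(lib, PathBuf::from("/path/to/libtorch/lib"));
}
```

In a real build script the three values would come from `std::env::var`; they are passed as arguments here to keep the sketch deterministic.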
2 changes: 1 addition & 1 deletion examples/min-gpt/main.rs
@@ -80,7 +80,7 @@ fn causal_self_attention(p: &nn::Path, cfg: Config) -> impl ModuleT {
let q = xs.apply(&query).view(sizes).transpose(1, 2);
let v = xs.apply(&value).view(sizes).transpose(1, 2);
let att = q.matmul(&k.transpose(-2, -1)) * (1.0 / f64::sqrt(sizes[3] as f64));
let att = att.masked_fill(&mask.i((.., .., ..sz_t, ..sz_t)).eq(0.), std::f64::NEG_INFINITY);
let att = att.masked_fill(&mask.i((.., .., ..sz_t, ..sz_t)).eq(0.), f64::NEG_INFINITY);
let att = att.softmax(-1, Kind::Float).dropout(cfg.attn_pdrop, train);
let ys = att.matmul(&v).transpose(1, 2).contiguous().view([sz_b, sz_t, sz_c]);
ys.apply(&proj).dropout(cfg.resid_pdrop, train)
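The `std::f64::NEG_INFINITY` to `f64::NEG_INFINITY` change in this hunk is purely cosmetic: the module-level constant and the associated constant are the same value, and clippy's `legacy_numeric_constants` lint prefers the associated form. A minimal check:

```rust
fn main() {
    // The associated constant is the preferred spelling; the legacy
    // std::f64 module constant it replaces has an identical value.
    assert_eq!(f64::NEG_INFINITY, -f64::INFINITY);
    assert!(f64::NEG_INFINITY.is_infinite());
    assert!(f64::NEG_INFINITY < 0.0);
}
```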
8 changes: 4 additions & 4 deletions examples/python-extension/Cargo.toml
@@ -17,7 +17,7 @@ name = "tch_ext"
crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.18.3", features = ["extension-module"] }
pyo3-tch = { path = "../../pyo3-tch", version = "0.14.0" }
tch = { path = "../..", features = ["python-extension"], version = "0.14.0" }
torch-sys = { path = "../../torch-sys", features = ["python-extension"], version = "0.14.0" }
pyo3 = { version = "0.21", features = ["extension-module"] }
pyo3-tch = { path = "../../pyo3-tch", version = "0.17.0" }
tch = { path = "../..", features = ["python-extension"], version = "0.17.0" }
torch-sys = { path = "../../torch-sys", features = ["python-extension"], version = "0.17.0" }
6 changes: 3 additions & 3 deletions examples/python-extension/src/lib.rs
@@ -3,15 +3,15 @@ use pyo3_tch::{wrap_tch_err, PyTensor};

#[pyfunction]
fn add_one(tensor: PyTensor) -> PyResult<PyTensor> {
let tensor = tensor.f_add_scalar(1.0).map_err(wrap_tch_err)?;
let tensor = tensor.f_add_scalar(1.0, 1.0).map_err(wrap_tch_err)?;
Ok(PyTensor(tensor))
}

/// A Python module implemented in Rust using tch to manipulate PyTorch
/// objects.
#[pymodule]
fn tch_ext(py: Python<'_>, m: &PyModule) -> PyResult<()> {
py.import("torch")?;
fn tch_ext(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
py.import_bound("torch")?;
m.add_function(wrap_pyfunction!(add_one, m)?)?;
Ok(())
}
1 change: 0 additions & 1 deletion examples/stable-diffusion/main.rs
@@ -36,7 +36,6 @@
//
// cargo run --release --example tensor-tools cp ./data/vae.npz ./data/vae.ot
// cargo run --release --example tensor-tools cp ./data/unet.npz ./data/unet.ot
///
// TODO: fix tensor_tools so that it works properly there.
// TODO: Split this file, probably in a way similar to huggingface/diffusers.
use std::collections::{HashMap, HashSet};
6 changes: 3 additions & 3 deletions examples/yolo/darknet.rs
@@ -17,7 +17,7 @@ struct Block {

impl Block {
fn get(&self, key: &str) -> Result<&str> {
match self.parameters.get(&key.to_string()) {
match self.parameters.get(key) {
None => bail!("cannot find {} in {}", key, self.block_type),
Some(value) => Ok(value),
}
@@ -32,7 +32,7 @@ pub struct Darknet {

impl Darknet {
fn get(&self, key: &str) -> Result<&str> {
match self.parameters.get(&key.to_string()) {
match self.parameters.get(key) {
None => bail!("cannot find {} in net parameters", key),
Some(value) => Ok(value),
}
@@ -199,7 +199,7 @@ where
slice.copy_(&src)
}

fn detect(xs: &Tensor, image_height: i64, classes: i64, anchors: &Vec<(i64, i64)>) -> Tensor {
fn detect(xs: &Tensor, image_height: i64, classes: i64, anchors: &[(i64, i64)]) -> Tensor {
let (bsize, _channels, height, _width) = xs.size4().unwrap();
let stride = image_height / height;
let grid_size = image_height / stride;
8 changes: 6 additions & 2 deletions gen/gen.ml
@@ -32,6 +32,7 @@ let excluded_functions =
; "_cummax_helper"
; "retain_grad"
; "_validate_sparse_coo_tensor_args"
; "_sparse_semi_structured_addmm"
; "_backward"
; "size"
; "stride"
@@ -101,7 +102,9 @@ let excluded_prefixes =
; "_amp_foreach"
; "_nested_tensor"
; "_fused_adam"
; "_fused_adagrad"
; "sym_"
; "_fused_sgd"
]

let excluded_suffixes = [ "_forward"; "_forward_out" ]
@@ -176,6 +179,7 @@ module Func = struct
| "at::tensoroptions" -> Some TensorOptions
| "at::intarrayref" -> Some (if is_nullable then IntListOption else IntList)
| "at::arrayref<double>" -> Some DoubleList
| "const c10::list<::std::optional<at::tensor>> &"
| "const c10::list<c10::optional<at::tensor>> &" -> Some TensorOptList
| "const at::itensorlistref &" | "at::tensorlist" -> Some TensorList
| "at::device" -> Some Device
@@ -598,7 +602,7 @@ let write_cpp funcs filename =
let pc s = p out_cpp s in
let ph s = p out_h s in
pc "// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!";
pc "#include \"%s.h\"" (Caml.Filename.basename filename);
pc "#include \"%s.h\"" (Stdlib.Filename.basename filename);
pc "";
ph "// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!";
ph "#include \"torch_api.h\"";
@@ -887,7 +891,7 @@ let run

let () =
run
~yaml_filename:"third_party/pytorch/Declarations-v2.1.0.yaml"
~yaml_filename:"third_party/pytorch/Declarations-v2.4.0.yaml"
~cpp_filename:"torch-sys/libtch/torch_api_generated"
~ffi_filename:"torch-sys/src/c_generated.rs"
~wrapper_filename:"src/wrappers/tensor_generated.rs"
8 changes: 4 additions & 4 deletions pyo3-tch/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "pyo3-tch"
version = "0.14.0"
version = "0.17.0"
authors = ["Laurent Mazare <[email protected]>"]
edition = "2021"
build = "build.rs"
@@ -12,6 +12,6 @@ categories = ["science"]
license = "MIT/Apache-2.0"

[dependencies]
tch = { path = "..", features = ["python-extension"], version = "0.14.0" }
torch-sys = { path = "../torch-sys", features = ["python-extension"], version = "0.14.0" }
pyo3 = { version = "0.18.3", features = ["extension-module"] }
tch = { path = "..", features = ["python-extension"], version = "0.17.0" }
torch-sys = { path = "../torch-sys", features = ["python-extension"], version = "0.17.0" }
pyo3 = { version = "0.21", features = ["extension-module"] }
5 changes: 1 addition & 4 deletions pyo3-tch/src/lib.rs
@@ -1,8 +1,5 @@
use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::{
exceptions::{PyTypeError, PyValueError},
AsPyPointer,
};
pub use tch;
pub use torch_sys;

4 changes: 2 additions & 2 deletions src/nn/init.rs
@@ -107,7 +107,7 @@ pub fn f_init(i: Init, dims: &[i64], device: Device) -> Result<Tensor, TchError>
// Optimize the case for which a single C++ code can be done.
if cst == 0. {
Tensor::f_zeros(dims, (Kind::Float, device))
} else if (cst - 1.).abs() <= std::f64::EPSILON {
} else if (cst - 1.).abs() <= f64::EPSILON {
Tensor::f_ones(dims, (Kind::Float, device))
} else {
Tensor::f_ones(dims, (Kind::Float, device)).map(|t| t * cst)
@@ -117,7 +117,7 @@ pub fn f_init(i: Init, dims: &[i64], device: Device) -> Result<Tensor, TchError>
Tensor::f_zeros(dims, (Kind::Float, device))?.f_uniform_(lo, up)
}
Init::Randn { mean, stdev } => {
if mean == 0. && (stdev - 1.).abs() <= std::f64::EPSILON {
if mean == 0. && (stdev - 1.).abs() <= f64::EPSILON {
Tensor::f_randn(dims, (Kind::Float, device))
} else {
Tensor::f_randn(dims, (Kind::Float, device)).map(|t| t * stdev + mean)
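The `f_init` hunks keep the same epsilon tolerance while modernizing the constant: values within `f64::EPSILON` of 1.0 are treated as exactly one so the cheaper `ones`/`randn` path is taken without an extra scaling multiply. The check in isolation:

```rust
// Same tolerance test as in f_init: treat a constant within one machine
// epsilon of 1.0 as exactly one, avoiding a redundant multiply by ~1.0.
fn is_approx_one(cst: f64) -> bool {
    (cst - 1.0).abs() <= f64::EPSILON
}

fn main() {
    assert!(is_approx_one(1.0));
    assert!(is_approx_one(1.0 + f64::EPSILON)); // boundary value still passes
    assert!(!is_approx_one(1.1));
    assert!(!is_approx_one(0.0));
}
```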
10 changes: 10 additions & 0 deletions src/nn/var_store.rs
@@ -151,6 +151,11 @@ impl VarStore {
///
/// Weight values for all the tensors currently stored in the
/// var-store are saved in the given file.
///
/// If the given path ends with the suffix `.safetensors`, the file will
/// be saved in safetensors format. Otherwise, libtorch C++ module format
/// will be used. Note that saving in pickle format (`.pt` extension) is
/// not supported by the C++ API of Torch.
pub fn save<T: AsRef<std::path::Path>>(&self, path: T) -> Result<(), TchError> {
let variables = self.variables_.lock().unwrap();
let named_tensors = variables.named_variables.iter().collect::<Vec<_>>();
@@ -216,6 +221,11 @@ impl VarStore {
/// var-store are loaded from the given file. Note that the set of
/// variables stored in the var-store is not changed, only the values
/// for these tensors are modified.
///
/// The format of the file is deduced from the file extension:
/// - `.safetensors`: The file is assumed to be in safetensors format.
/// - `.bin` or `.pt`: The file is assumed to be in pickle format.
/// - Otherwise, the file is assumed to be in libtorch C++ module format.
pub fn load<T: AsRef<std::path::Path>>(&mut self, path: T) -> Result<(), TchError> {
if self.device != Device::Mps {
self.load_internal(path)
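The extension-based format selection these doc comments describe can be sketched as a small dispatch function (names here are illustrative, not the crate's actual internals):

```rust
use std::path::Path;

// Hypothetical sketch of the load-format deduction documented above:
// `.safetensors` -> safetensors, `.bin`/`.pt` -> pickle, anything else ->
// libtorch C++ module format.
#[derive(Debug, PartialEq)]
enum LoadFormat {
    SafeTensors,
    Pickle,
    CppModule,
}

fn load_format(path: &Path) -> LoadFormat {
    match path.extension().and_then(|e| e.to_str()) {
        Some("safetensors") => LoadFormat::SafeTensors,
        Some("bin") | Some("pt") => LoadFormat::Pickle,
        _ => LoadFormat::CppModule,
    }
}

fn main() {
    assert_eq!(load_format(Path::new("model.safetensors")), LoadFormat::SafeTensors);
    assert_eq!(load_format(Path::new("model.pt")), LoadFormat::Pickle);
    assert_eq!(load_format(Path::new("model.ot")), LoadFormat::CppModule);
}
```

Note the asymmetry documented in `save`: pickle (`.pt`) is listed only for loading, since the C++ API of Torch does not support saving in pickle format.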