Skip to content

Commit

Permalink
style: Reformat some code
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt authored and lucianopaz committed Feb 11, 2025
1 parent 630a092 commit d15fba1
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 13 deletions.
2 changes: 1 addition & 1 deletion python/nutpie/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
from nutpie.sample import sample

__version__: str = _lib.__version__
__all__ = ["__version__", "sample", "compile_pymc_model", "compile_stan_model"]
__all__ = ["__version__", "compile_pymc_model", "compile_stan_model", "sample"]
8 changes: 6 additions & 2 deletions python/nutpie/compile_pymc.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import dataclasses
import itertools
import warnings
from collections.abc import Iterable
from dataclasses import dataclass
from functools import wraps
from importlib.util import find_spec
from math import prod
from typing import TYPE_CHECKING, Any, Callable, Iterable, Literal, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -500,7 +501,10 @@ def compile_pymc_model(
if gradient_backend == "jax":
raise ValueError("Gradient backend cannot be jax when using numba backend")
return _compile_pymc_model_numba(
model=model, pymc_initial_point_fn=initial_point_fn, var_names=var_names, **kwargs
model=model,
pymc_initial_point_fn=initial_point_fn,
var_names=var_names,
**kwargs,
)
elif backend.lower() == "jax":
return _compile_pymc_model_jax(
Expand Down
6 changes: 3 additions & 3 deletions src/pyfunc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use std::sync::Arc;
use anyhow::{anyhow, bail, Context, Result};
use arrow::{
array::{
Array, ArrayBuilder, BooleanBuilder, FixedSizeListBuilder, Float32Builder, Float64Builder,
Int64Builder, LargeListBuilder, ListBuilder, PrimitiveBuilder, StructBuilder,
Array, ArrayBuilder, BooleanBuilder, Float32Builder, Float64Builder, Int64Builder,
LargeListBuilder, PrimitiveBuilder, StructBuilder,
},
datatypes::{DataType, Field, Float32Type, Float64Type, Int64Type},
};
Expand All @@ -16,7 +16,7 @@ use pyo3::{
Bound, Py, PyAny, PyErr, Python,
};
use rand::Rng;
use rand_distr::{Distribution, StandardNormal, Uniform};
use rand_distr::{Distribution, Uniform};
use smallvec::SmallVec;
use thiserror::Error;

Expand Down
5 changes: 1 addition & 4 deletions src/pymc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@ use std::{ffi::c_void, fmt::Display, sync::Arc};

use anyhow::{bail, Context, Result};
use arrow::{
array::{
Array, FixedSizeListArray, Float64Array, LargeListArray, LargeListBuilder, StructArray,
},
array::{Array, Float64Array, LargeListArray, StructArray},
buffer::OffsetBuffer,
datatypes::{DataType, Field, Fields},
};
Expand All @@ -16,7 +14,6 @@ use pyo3::{
types::{PyAnyMethods, PyList},
Bound, Py, PyAny, PyObject, PyResult, Python,
};
use rand::{distributions::Uniform, prelude::Distribution};

use thiserror::Error;

Expand Down
15 changes: 12 additions & 3 deletions tests/test_pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,10 @@ def test_pymc_var_names(backend, gradient_backend):
pm.Deterministic("c", mu * b)

compiled = nutpie.compile_pymc_model(
model, backend=backend, gradient_backend=gradient_backend, var_names=None,
model,
backend=backend,
gradient_backend=gradient_backend,
var_names=None,
)
trace = nutpie.sample(compiled, chains=1, seed=1)

Expand All @@ -213,7 +216,10 @@ def test_pymc_var_names(backend, gradient_backend):
assert hasattr(trace.posterior, "c")

compiled = nutpie.compile_pymc_model(
model, backend=backend, gradient_backend=gradient_backend, var_names=[],
model,
backend=backend,
gradient_backend=gradient_backend,
var_names=[],
)
trace = nutpie.sample(compiled, chains=1, seed=1)

Expand All @@ -222,7 +228,10 @@ def test_pymc_var_names(backend, gradient_backend):
assert not hasattr(trace.posterior, "c")

compiled = nutpie.compile_pymc_model(
model, backend=backend, gradient_backend=gradient_backend, var_names=["b"],
model,
backend=backend,
gradient_backend=gradient_backend,
var_names=["b"],
)
trace = nutpie.sample(compiled, chains=1, seed=1)

Expand Down

0 comments on commit d15fba1

Please sign in to comment.