Skip to content

Commit

Permalink
all variant
Browse files Browse the repository at this point in the history
  • Loading branch information
joelberkeley committed Jan 20, 2024
1 parent 86f08cd commit 58b6a9f
Show file tree
Hide file tree
Showing 10 changed files with 58 additions and 1,398 deletions.
4 changes: 4 additions & 0 deletions src/Compiler/Expr.idr
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ export
empty : Env
empty = MkEnv 0 []

export
addNode1 : Expr -> Env -> (Nat, Env)
addNode1 x (MkEnv n xs) = (n, MkEnv (S n) ((n, x) :: xs))

export
addNode : Expr -> State Env Nat
addNode expr = do
Expand Down
62 changes: 42 additions & 20 deletions src/Tensor.idr
Original file line number Diff line number Diff line change
Expand Up @@ -110,44 +110,66 @@ panicIO x = runEitherT x <&> \case
||| **Note:**
||| * Each call to `eval` will rebuild and execute the graph. Similarly, multiple calls to
||| `eval` on different `Tensor`s in a computation will be treated entirely independently.
||| `eval` does not store intermediate values. This is a known limitation, and may change in
||| the future.
||| * `eval` performs logging. You can disable this by adjusting the TensorFlow logging level
||| with e.g. `export TF_CPP_MIN_LOG_LEVEL=3`.
||| `eval` does not store intermediate values. If you want to evaluate multiple tensors, use
||| `Tuple.eval`.
||| * `eval` performs logging. You can disable this by adjusting the logging level
||| with e.g. `export TF_CPP_MIN_LOG_LEVEL=3`.
export partial
eval : PrimitiveRW dtype ty => Graph (Tensor shape dtype) -> IO (Literal shape ty)
eval $ MkGraph x =
let (env, MkTensor root) = runState empty x
in panicIO $ execute (MkFn [] root env) >>= read {dtype} []

namespace TensorList
public export
data TensorList : List (Shape, Type) -> Type where
Nil : TensorList []
(::) : PrimitiveRW dtype ty =>
Tensor shape dtype ->
TensorList sts ->
TensorList ((shape, ty) :: sts)

namespace Tuple
export partial
eval : Graph (TensorList shapes) -> IO $ All (uncurry Literal) shapes
eval $ MkGraph tensors = do
eval : Graph (All2 Tensor shapes dtypes) ->
All2 PrimitiveRW dtypes tys =>
IO $ All2 Literal shapes tys
eval @{prims} $ MkGraph tensors = do
let graph = do ts <- tensors
x <- addNode (Tuple $ nodes ts)
pure (x, ts)
(env, root, tensors) = runState empty graph
panicIO $ execute (MkFn [] root env) >>= readAll tensors 0
panicIO $ execute (MkFn [] root env) >>= readAll tensors prims 0

where

nodes : TensorList s -> List Nat
nodes : All2 Tensor ss ds -> List Nat
nodes [] = []
nodes (MkTensor x :: xs) = x :: nodes xs

readAll : HasIO io => TensorList s -> Nat -> Literal -> io $ All (uncurry Literal) s
readAll [] _ _ = pure []
readAll (MkTensor {dtype} _ :: ts) n lit = [| read {dtype} [n] lit :: readAll ts (S n) lit |]
readAll : HasIO io =>
All2 Tensor ss ds ->
All2 PrimitiveRW ds ts ->
Nat ->
Literal ->
io $ All2 Literal ss ts
readAll [] [] _ _ = pure []
readAll (MkTensor {dtype} _ :: ts) (prim :: prims) n lit =
[| read {dtype} [n] lit :: readAll ts prims (S n) lit |]

partial
foo : IO Bool
foo = do
let x0 : Literal [] Double
y0 := tensor {dtype = F64} x0
[x0'] <- eval {tys = %search} (do pure [!y0])
pure (x0 == x0')

namespace Example
interface RW (a : Type) (b : Type) | a where

RW Nat Int32 where
RW Bool Double where

eval : All2 Literal shape as -> All2 RW as bs => All2 Literal shape bs

eq : Bool
eq = let xs : Literal [2] Int32
xs' : Literal [2] Nat

[xs''] := eval [xs']
in xs == xs''
||| A string representation of the graph used to define a `Tensor`, detailing all enqueued XLA
||| operations.
Expand Down
6 changes: 6 additions & 0 deletions src/Util.idr
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,12 @@ namespace List
impl _ [] = []
impl i (x :: xs) = if elem i idxs then impl (S i) xs else x :: impl (S i) xs

namespace All2
public export
data All2 : (0 p : a -> b -> Type) -> List a -> List b -> Type where
Nil : All2 p [] []
(::) : forall xs, ys . p x y -> All2 p xs ys -> All2 p (x :: xs) (y :: ys)

||| A `Sorted f xs` proves that for all consecutive elements `x` and `y` in `xs`, `f x y` exists.
||| For example, a `Sorted LT xs` proves that all `Nat`s in `xs` appear in increasing numerical
||| order.
Expand Down
6 changes: 0 additions & 6 deletions test.ipkg
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,6 @@ main = Main
modules =
Unit.Model.TestKernel,

Unit.TestTensor.Elementwise,
Unit.TestTensor.HigherOrder,
Unit.TestTensor.Sampling,
Unit.TestTensor.Slice,
Unit.TestTensor.Structure,

Unit.TestDistribution,
Unit.TestLiteral,
Unit.TestTensor,
Expand Down
23 changes: 6 additions & 17 deletions test/Unit/TestTensor.idr
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,6 @@ limitations under the License.
--}
module Unit.TestTensor

import Unit.TestTensor.Elementwise
import Unit.TestTensor.HigherOrder
import Unit.TestTensor.Sampling
import Unit.TestTensor.Slice
import Unit.TestTensor.Structure

import Data.Nat
import Data.Vect
import System
Expand Down Expand Up @@ -67,9 +61,9 @@ evalTuple = property $ do
y1 = tensor {dtype = S32} x1
y2 = tensor {dtype = U64} x2

let [] = unsafePerformIO $ eval (pure [])
-- let [] = unsafePerformIO $ eval {tys = []} (pure [])

let [x0'] = unsafePerformIO $ eval (do pure [!y0])
let [x0'] = unsafePerformIO $ eval {tys = [_]} (do pure [!y0])

x0' ==~ x0

Expand All @@ -87,14 +81,15 @@ evalTuple = property $ do
partial
evalTupleNonTrivial : Property
evalTupleNonTrivial = property $ do
let xs = do y0 <- tensor [1.0, -2.0, 0.4]
let xs : Graph $ All2 Tensor [[], [2]] _ =
do y0 <- tensor [1.0, -2.0, 0.4]
y1 <- tensor 3.0
u <- exp y0
v <- slice [at 1] u + pure y1
w <- slice [0.to 2] u
pure [v, w]

[v, w] = unsafePerformIO $ eval xs
[v, w] = unsafePerformIO $ eval {shapes = [[], [2]]} {tys = [_, _]} xs

v ==~ Scalar (exp (-2.0) + 3.0)
w ==~ [| exp [1.0, -2.0] |]
Expand Down Expand Up @@ -414,10 +409,4 @@ group = MkGroup "Tensor" $ [
, (#"(|\) and (/|) result and inverse"#, triangularSolveResultAndInverse)
, (#"(|\) and (/|) ignore opposite elements"#, triangularSolveIgnoresOppositeElems)
, ("trace", trace)
] ++ concat (the (List _) [
Unit.TestTensor.Elementwise.all
, Unit.TestTensor.HigherOrder.all
, Unit.TestTensor.Sampling.all
, Unit.TestTensor.Slice.all
, Unit.TestTensor.Structure.all
])
]
Loading

0 comments on commit 58b6a9f

Please sign in to comment.