Skip to content

Commit

Permalink
fixed bugs where the negate function wasn't treated as a math expr an…
Browse files Browse the repository at this point in the history
…d improved the pretty printer when negates are displayed
  • Loading branch information
mikeizbicki committed Oct 1, 2015
1 parent 88ab16f commit 9437535
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 23 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ If you compile the linear package with Herbie, the code above gets rewritten to:

```
w :: Double -> Double -> Double
w far near = if far < (negate 1.7210442634149447e81)
then (far / (far - near)) * 2 * near
w far near = if far < -1.7210442634149447e81
then ((-2 * far) / (far - near)) * near
else if far < 8.364504563556443e16
then (far * 2) / ((far - near) / near)
else (far / (far - near)) * 2 * near
then -2 * far * (near / (far - near))
else ((-2 * far) / (far - near)) * near
```

This modified code is numerically stable.
Expand Down
2 changes: 2 additions & 0 deletions src/Herbie/MathExpr.hs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ monOpList =
, "sqrt"
, "abs"
, "size"
, "negate"
]

binOpList = [ "^", "**", "^^", "/", "-", "expt" ] ++ commutativeOpList
Expand Down Expand Up @@ -180,6 +181,7 @@ mathExpr2lisp :: MathExpr -> String
mathExpr2lisp = go
where
go (EBinOp op a1 a2) = "("++op++" "++go a1++" "++go a2++")"
go (EMonOp "negate" a) = "(- "++go a++")"
go (EMonOp op a) = "("++op++" "++go a++")"
go (EIf cond e1 e2) = "(if "++go cond++" "++go e1++" "++go e2++")"
go (ELeaf e) = e
Expand Down
6 changes: 5 additions & 1 deletion src/Herbie/MathInfo.hs
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,20 @@ pprMathInfo mathInfo = go 1 False $ getMathExpr mathInfo
else str
where
str = case e of
EMonOp op e1 -> op++" "++ go i True e1
-- EMonOp "negate" l@(ELit _) -> "-"++go i False l
EMonOp "negate" e1 -> "-"++go i False e1
EMonOp op e1 -> op++" "++go i True e1

EBinOp op e1 e2 -> go i parens1 e1++" "++op++" "++go i parens2 e2
where
parens1 = case e1 of
(EBinOp op' _ _) -> op/=op'
(EMonOp _ _) -> False
_ -> True

parens2 = case e2 of
(EBinOp op' _ _) -> op/=op' || not (op `elem` commutativeOpList)
(EMonOp _ _) -> False
_ -> True

ELit l -> if toRational (floor l) == l
Expand Down
4 changes: 0 additions & 4 deletions test/Tests.hs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ module Main

import SubHask

{-
--------------------------------------------------------------------------------

-- This section tests that Herbie obeys the code annotations
Expand Down Expand Up @@ -295,9 +294,6 @@ example87 x = exp x / sqrt(exp x - 1) * sqrt x
example88 x = (exp(x) - 1) / x

example89 x = sqrt(x + 2) - sqrt(x)
-}

--------------------------------------------------------------------------------

--------------------------------------------------------------------------------

Expand Down
29 changes: 15 additions & 14 deletions test/ValidRewrite.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{-# LANGUAGE GADTs,RebindableSyntax,CPP,FlexibleContexts,FlexibleInstances,ConstraintKinds #-}
{-# LANGUAGE StandaloneDeriving,DeriveDataTypeable #-}
{-# OPTIONS_GHC -dcore-lint #-}
{-
- This test suite ensures that the rewrites that HerbiePlugin performs
- give the correct results.
Expand Down Expand Up @@ -27,11 +28,11 @@ test1b far near = -(2 * far * near) / (far - near)

{-# ANN test1c "NoHerbie" #-}
test1c :: Double -> Double -> Double
test1c far near = -(if far < (negate 1.7210442634149447e81)
then (far / (far - near)) * 2 * near
test1c far near = if far < -1.7210442634149447e81
then ((-2 * far) / (far - near)) * near
else if far < 8.364504563556443e16
then (far * 2) / ((far - near) / near)
else (far / (far - near)) * 2 * near)
then -2 * far * (near / (far - near))
else ((-2 * far) / (far - near)) * near

--------------------

Expand Down Expand Up @@ -85,16 +86,16 @@ atanh_ x = 0.5 * log ((1.0+x) / (1.0-x))
putStrLn ""

main = do
-- mkTest(test1a,test1b,-2e90,6)
-- mkTest(test1a,test1b,3,4)
-- mkTest(test1a,test1b,2e90,6)
--
-- mkTest(test1a,test1c,-2e90,6)
-- mkTest(test1a,test1c,3,4)
-- mkTest(test1a,test1c,2e90,6)
--
-- mkTest(test2a,test2b,1,2)
--
mkTest(test1a,test1b,-2e90,6)
mkTest(test1a,test1b,3,4)
mkTest(test1a,test1b,2e90,6)

mkTest(test1a,test1c,-2e90,6)
mkTest(test1a,test1c,3,4)
mkTest(test1a,test1c,2e90,6)

mkTest(test2a,test2b,1,2)

-- mkTest(test3a,test3b,(Quaternion 1 (V3 1 2 3)),(Quaternion 2 (V3 2 3 4)))

-- mkTestB(asinh,asinh_,5e-17::Complex Double)
Expand Down

0 comments on commit 9437535

Please sign in to comment.