Skip to content

Commit

Permalink
fully generify numerical precision
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Jan 24, 2020
1 parent b65d9ad commit 838f832
Show file tree
Hide file tree
Showing 9 changed files with 377 additions and 8,754 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -374,8 +374,23 @@ sealed class Protocol<X : SFun<X>> {
val x = Var<X>("x")
val y = Var<X>("y")
val z = Var<X>("z")

val variables = listOf(x, y, z)

val zero = Zero<X>()
val one = One<X>()
val two = Two<X>()
val e = E<X>()

val constants: Map<Special<X>, Number> = mapOf(zero to 0, one to 1, two to 2, e to E)

private fun <T: Map<Var<X>, Number>> T.bind() =
Bindings((constants + this@bind).map { it.key to wrap(it.value) }.toMap())

abstract fun wrap(number: Number): X
fun <X: RealNumber<X, Y>, Y: Number> SFun<X>.unwap() = (this as X).value
fun <X: RealNumber<X, Y>, Y: Number> SFun<X>.toDouble() = unwap().toDouble()

fun <T: SFun<T>> sin(angle: SFun<T>) = angle.sin()
fun <T: SFun<T>> cos(angle: SFun<T>) = angle.cos()
fun <T: SFun<T>> tan(angle: SFun<T>) = angle.tan()
Expand All @@ -388,19 +403,13 @@ sealed class Protocol<X : SFun<X>> {
infix operator fun div(arg: Differential<X>) = fx.d(arg.fx.bindings.sVars.first())
}

abstract val constants: List<Pair<Special<X>, Number>>

protected fun <T: Pair<Var<X>, Number>> Array<T>.bind() =
Bindings((constants + this@bind).map { it.first to wrap(it.second) }.toMap())
operator fun SFun<X>.invoke(vararg pairs: Pair<Var<X>, Number>) = this(pairs.toMap().bind())

operator fun SFun<X>.invoke(vararg pairs: Pair<Var<X>, Number>) = this(pairs.bind())
operator fun <E : D1> VFun<X, E>.invoke(vararg pairs: Pair<Var<X>, Number>) = this(pairs.toMap().bind())

operator fun <Y : D1> VFun<X, Y>.invoke(vararg pairs: Pair<Var<X>, Number>) = this(pairs.bind())

operator fun <Rows : D1, Cols: D1> MFun<X, Rows, Cols>.invoke(vararg pairs: Pair<Var<X>, Number>) = this(pairs.bind())
operator fun <R : D1, C: D1> MFun<X, R, C>.invoke(vararg pairs: Pair<Var<X>, Number>) = this(pairs.toMap().bind())

fun d(fn: SFun<X>) = Differential(fn)
abstract fun wrap(default: Number): X

operator fun Number.times(multiplicand: SFun<X>) = wrap(this) * multiplicand
operator fun SFun<X>.times(multiplicand: Number) = this * wrap(multiplicand)
Expand Down Expand Up @@ -449,27 +458,9 @@ sealed class Protocol<X : SFun<X>> {
}

object DoublePrecision : Protocol<DReal>() {
override fun wrap(default: Number): DReal = DReal(default.toDouble())

override val constants: List<Pair<Special<DReal>, Number>> = listOf(
Zero<DReal>() to 0,
One<DReal>() to 1,
Two<DReal>() to 2,
E<DReal>() to E
)

fun SFun<DReal>.asDouble() = (this as DReal).value
override fun wrap(number: Number): DReal = DReal(number.toDouble())
}

object BigDecimalPrecision : Protocol<BDReal>() {
override fun wrap(default: Number): BDReal = BDReal(default.toDouble())

override val constants: List<Pair<Special<BDReal>, Number>> = listOf(
Zero<BDReal>() to 0,
One<BDReal>() to 1,
Two<BDReal>() to 2,
E<BDReal>() to E
)

fun SFun<BDReal>.asDouble() = (this as BDReal).value.toDouble()
override fun wrap(number: Number): BDReal = BDReal(number.toDouble())
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class TestArithmetic: StringSpec({

"test division" {
DoubleGenerator(0).assertAll { ẋ, ẏ ->
(x / y)(x to ẋ, y to ẏ) shouldBeAbout x(x to ẋ).asDouble() / y(y to ẏ).asDouble()
(x / y)(x to ẋ, y to ẏ) shouldBeAbout x(x to ẋ).toDouble() / y(y to ẏ).toDouble()
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ fun main() {
val `dy∕dx` = d(y) / d(x)

xs.map { BigDecimal(it) }.run {
val f = map { y(x to it).asDouble() }
val df = map { `dy∕dx`(x to it).asDouble() }
val f = map { y(x to it).toDouble() }
val df = map { `dy∕dx`(x to it).toDouble() }
arrayOf(f, df)
}.map { it.toDoubleArray() }.toTypedArray()
}
Expand All @@ -55,7 +55,7 @@ fun main() {
// dy/dx=$`dy∕dx`
// """.trimIndent())

xs.run { arrayOf(map { y(x to it).asDouble() }, map { `dy∕dx`(x to it).asDouble() }) }.map { it.toDoubleArray() }.toTypedArray()
xs.run { arrayOf(map { y(x to it).toDouble() }, map { `dy∕dx`(x to it).toDouble() }) }.map { it.toDoubleArray() }.toTypedArray()
}

// Numerical differentiation using centered differences
Expand All @@ -64,7 +64,7 @@ fun main() {
val h = 7E-13
val `dy∕dx` = (y(x to x + h) - y(x to x - h)) / (2.0 * h)

xs.run { arrayOf(map { y(x to it).asDouble() }, map { `dy∕dx`(x to it).asDouble() }) }.map { it.toDoubleArray() }.toTypedArray()
xs.run { arrayOf(map { y(x to it).toDouble() }, map { `dy∕dx`(x to it).toDouble() }) }.map { it.toDoubleArray() }.toTypedArray()
}

val t = { i: Double, d: Double -> log10(abs(i - d)).let { if (it < -20) -20.0 else it } }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package edu.umontreal.kotlingrad.samples
import edu.umontreal.kotlingrad.numerical.DoublePrecision

@Suppress("NonAsciiCharacters", "LocalVariableName")
fun main() {
fun main() =
with(DoublePrecision) {
val x = Var("x")
val y = Var("y")
Expand All @@ -25,5 +25,4 @@ fun main() {
"∂²z($values)/∂x² \t\t= $`∂z∕∂y` \n\t\t\t\t= " + `∂²z∕∂x²`(values) + "\n" +
"∂²z($values)/∂x∂y \t\t= $`∂²z∕∂x∂y` \n\t\t\t\t= " + `∂²z∕∂x∂y`(values) + "\n" +
"∇z($values) \t\t\t= $`∇z` \n\t\t\t\t= [${`∇z`[x]!!(values)}, ${`∇z`[y]!!(values)}]ᵀ")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import jetbrains.letsPlot.geom.geom_path
import jetbrains.letsPlot.ggplot
import jetbrains.letsPlot.ggtitle
import jetbrains.letsPlot.intern.toSpec
import java.io.File
import kotlin.math.*

fun main() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ class MultilayerPerceptron<T: SFun<T>>(
return lossFun(p1v to p1)(p2v to p2)(p3v to p3)
}


private fun layer(x: VFun<T, D3>): VFun<T, D3> = x.map { sigmoid(it) }

companion object {
Expand Down Expand Up @@ -51,7 +50,7 @@ class MultilayerPerceptron<T: SFun<T>>(
val (X, Y) = drawSample()
val m = mlp(p1 = w1, p2 = w2, p3 = w3)

totalLoss += m(mlp.x to X, mlp.y to Y).asDouble()
totalLoss += m(mlp.x to X, mlp.y to Y).toDouble()

val dw1 = m.d(mlp.p1v)
val dw2 = m.d(mlp.p2v)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ private fun DoublePrecision.plot2D(range: ClosedFloatingPointRange<Double>,
vararg funs: SFun<DReal>): String {
val labels = arrayOf("y", "dy/dx", "d²y/x²", "d³y/dx³", "d⁴y/dx⁴", "d⁵y/dx⁵")
val xs = (range step 0.0087).toList()
val ys = funs.map { xs.map { xv -> it(x to xv).asDouble() } }
val ys = funs.map { xs.map { xv -> it(x to xv).toDouble() } }
val data = (labels.zip(ys) + ("x" to xs)).toMap()
val colors = listOf("dark_green", "gray", "black", "red", "orange", "dark_blue")
val geoms = labels.zip(colors).map { geom_path(size = 2.0, color = it.second) { x = "x"; y = it.first } }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class DoublePendulum(private val len: Double = 900.0) : Application(), EventHand
val m1 = 2.0 // Masses
val m2 = 2.0
var G: SFun<DReal> = DoublePrecision.wrap(9.81) // Gravity
var µ: SFun<DReal> = DoublePrecision.wrap(0.01) // Friction
var µ = 0.01 // Friction
val Gp = 0.01 // Simulate measurement error
val µp = -0.01
var r1 = DoublePrecision.Vec(1.0, 0.0) // Polar vector
Expand Down Expand Up @@ -64,15 +64,19 @@ class DoublePendulum(private val len: Double = 900.0) : Application(), EventHand
if(isObserving && obs != null) {
val loss =2 - obs) pow 2
G = loss.descend(100, 0.0, 0.9, 0.1, G as Var<DReal> to wrap(priorVal))
priorVal = G.asDouble()
priorVal = G.toDouble()
println(G)
}

val r1a = (r1.angle +1 * dt + .5 * α1 * dt * dt)).run {
if(G is Var) this(G as Var<DReal> to priorVal).asDouble() else try { asDouble() } catch(e: Exception) {println(this); this(bindings.sVars.first() to priorVal).asDouble() }
if(G is Var) this(G as Var<DReal> to priorVal).toDouble() else try {
toDouble()
} catch(e: Exception) {println(this); this(bindings.sVars.first() to priorVal).toDouble() }
}
val r2a = (r2.angle + -2 * dt + .5 * α2 * dt * dt)).run {
if(G is Var) this(G as Var<DReal> to priorVal).asDouble() else try { asDouble() } catch(e: Exception) {println(this); this(bindings.sVars.first() to priorVal).asDouble() }
if(G is Var) this(G as Var<DReal> to priorVal).toDouble() else try {
toDouble()
} catch(e: Exception) {println(this); this(bindings.sVars.first() to priorVal).toDouble() }
}

if(G is Var) {
Expand Down Expand Up @@ -100,7 +104,7 @@ class DoublePendulum(private val len: Double = 900.0) : Application(), EventHand
velocity = gamma * velocity + d_dg(map.first to G1P) * α
G1P -= velocity
i++
} while (abs(velocity.asDouble()) > 0.00001 && i < steps)
} while (abs(velocity.toDouble()) > 0.00001 && i < steps)

return G1P
}
Expand All @@ -126,12 +130,14 @@ class DoublePendulum(private val len: Double = 900.0) : Application(), EventHand

override fun start(stage: Stage) {
twin = DoublePendulum().apply {
G += DoublePrecision.wrap(Gp) // Perturb gravity and friction
µ += DoublePrecision.wrap(µp)
with(DoublePrecision) {
G += Gp // Perturb gravity and friction
µ += µp

// TODO: Erase these parameters -- should be fully learned
// TODO: Erase these parameters -- should be fully learned
// µ = 0.0
// G = 0.0
}
}

rod3 = twin.rod1
Expand All @@ -151,15 +157,15 @@ class DoublePendulum(private val len: Double = 900.0) : Application(), EventHand
}

val Vec<DReal, D2>.r: Double
get() = DoublePrecision.run { this@r[0].asDouble() }
get() = DoublePrecision.run { this@r[0].toDouble() }
val Vec<DReal, D2>.theta: Double
get() = DoublePrecision.run { this@theta[1].asDouble() }
get() = DoublePrecision.run { this@theta[1].toDouble() }

val Vec<DReal, D2>.end: Vec<DReal, D2>
get() = DoublePrecision.run { Vec(this@end.r + rodLen * magn * cos(angle), this@end.theta - rodLen * magn * sin(angle)) }

val Vec<DReal, D2>.magn: Double
get() = DoublePrecision.run { magnitude().asDouble() }
get() = DoublePrecision.run { magnitude().toDouble() }

val Vec<DReal, D2>.angle: Double
get() = DoublePrecision.run { atan2(this@angle.theta, this@angle.r) }
Expand Down
Loading

0 comments on commit 838f832

Please sign in to comment.