Skip to content

Commit

Permalink
Minor optimizations for ZPure's runloop (#1474)
Browse files Browse the repository at this point in the history
  • Loading branch information
kyri-petrou authored Feb 10, 2025
1 parent 8b1e7b4 commit 90ea08c
Showing 1 changed file with 52 additions and 81 deletions.
133 changes: 52 additions & 81 deletions core/shared/src/main/scala/zio/prelude/fx/ZPure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import zio.prelude._
import zio.prelude.coherent.CovariantIdentityBoth

import java.util.concurrent.atomic.AtomicBoolean
import scala.annotation.nowarn
import scala.reflect.ClassTag
import scala.util.Try

Expand Down Expand Up @@ -143,7 +144,7 @@ sealed trait ZPure[+W, -S1, +S2, -R, +E, +A] { self =>
* logs written in a failed computation will be cleared.
*/
final def clearLogOnError: ZPure[W, S1, S2, R, E, A] =
Flag(FlagType.ClearLogOnError, value = true, self)
ClearLogOnError(value = true, self)

/**
* Transforms the result of this computation with the specified partial
Expand Down Expand Up @@ -274,7 +275,7 @@ sealed trait ZPure[+W, -S1, +S2, -R, +E, +A] { self =>
* logs written in a failed computation will be kept (this is the default behavior).
*/
final def keepLogOnError: ZPure[W, S1, S2, R, E, A] =
Flag(FlagType.ClearLogOnError, value = false, self)
ClearLogOnError(value = false, self)

/**
* Returns a successful computation if the value is `Left`, or fails with error `None`.
Expand Down Expand Up @@ -1123,17 +1124,11 @@ object ZPure {
private final case class Provide[W, S1, S2, R, E, A](r: ZEnvironment[R], continue: ZPure[W, S1, S2, R, E, A])
extends ZPure[W, S1, S2, Any, E, A]
private final case class Log[S, +W](log: W) extends ZPure[W, S, S, Any, Nothing, Unit]
private final case class Flag[W, S1, S2, R, E, A](
flag: FlagType,
private final case class ClearLogOnError[W, S1, S2, R, E, A](
value: Boolean,
continue: ZPure[W, S1, S2, R, E, A]
) extends ZPure[W, S1, S2, R, E, A]

sealed trait FlagType
object FlagType {
case object ClearLogOnError extends FlagType
}

private object Runner {
private[this] val pool = new ThreadLocal[(Runner, AtomicBoolean)] {
override def initialValue(): (Runner, AtomicBoolean) = (new Runner(), new AtomicBoolean(false))
Expand All @@ -1156,10 +1151,6 @@ object ZPure {
new Runner().run(state, zPure)
}
}

final private case class Err(cause: Any) extends Exception {
override def fillInStackTrace(): Throwable = this
}
}

final private class Runner private {
Expand All @@ -1182,148 +1173,128 @@ object ZPure {
state: S1,
zPure: ZPure[W, S1, S2, R, E, A]
): (Chunk[W], Either[E, (S2, A)]) = {
val result =
try
Right(loop(state, zPure.asInstanceOf[Erased]))
catch {
case Runner.Err(c) => Left(c.asInstanceOf[E])
}
val result = loop(state, zPure.asInstanceOf[Erased])

(_logs.result().asInstanceOf[Chunk[W]], result)
}

private def loop[S2, A](state: Any, zPure: Erased) = {
@nowarn("msg=type argument")
private def loop[S2, E, A](state: Any, zPure: Erased): Either[E, (S2, A)] = {
// NOTE: Be careful not to add these into a lambda to avoid them being moved to the heap
var s0: Any = state
var a: Any = null
var curZPure = zPure

while (curZPure ne null)
curZPure match {
case flatmap0: ZPure.FlatMap[_, _, _, _, _, _, _, _] =>
val zPure = flatmap0.asInstanceOf[FlatMap[Any, Any, Any, Any, Any, Any, Any, Any]]
val nested = zPure.value
val continuation = zPure.continue
case flatmap0: ZPure.FlatMap[Any, Any, Any, Any, Any, Any, Any, Any] =>
val nested = flatmap0.value
val continuation = flatmap0.continue

nested match {
case succeed0: Succeed[_] =>
val zPure2 = succeed0.asInstanceOf[Succeed[Any]]
curZPure = continuation(zPure2.value)
case succeed0: Succeed[Any] =>
curZPure = continuation(succeed0.value)

case modify0: Modify[_, _, _] =>
val zPure2 = modify0.asInstanceOf[Modify[Any, Any, Any]]
val updated = zPure2.run0(s0)
case modify0: Modify[Any, Any, Any] =>
val updated = modify0.run0(s0)
s0 = updated._2
curZPure = continuation(updated._1)

case log0: Log[_, _] =>
val zPure = log0.asInstanceOf[Log[Any, Any]]
_logs addOne zPure.log
case log0: Log[Any, Any] =>
_logs addOne log0.log
curZPure = continuation(())

case environment0: Environment[_, _, _, _, _, _] =>
val zPure = environment0.asInstanceOf[Environment[Any, Any, Any, Any, Any, Any]]
curZPure = continuation(zPure.access(_environment))
case environment0: Environment[Any, Any, Any, Any, Any, Any] =>
curZPure = continuation(environment0.access(_environment))

case _ =>
curZPure = nested
stack.push(continuation)
}

case succeed0: Succeed[_] =>
val zPure = succeed0.asInstanceOf[Succeed[Any]]
a = zPure.value
case succeed0: Succeed[Any] =>
a = succeed0.value
val nextInstr = stack.pop()
if (nextInstr eq null) curZPure = null else curZPure = nextInstr(a)

case fold0: Fold[_, _, _, _, _, _, _, _, _] =>
val zPure = fold0.asInstanceOf[Fold[Any, Any, Any, Any, Any, Any, Any, Any, Any]]
case fold0: Fold[Any, Any, Any, Any, Any, Any, Any, Any, Any] =>
val state = s0
val clear = _clearLogOnError
val fold = if (clear) {
val previousLogs = _logs
_logs = ChunkBuilder.make()

ZPure.Fold(
zPure.value,
fold0.value,
(error: Any) => {
_logs = previousLogs
ZPure.set(state) *> zPure.failure(error)
ZPure.set(state) *> fold0.failure(error)
},
(a: Any) => {
val logs0 = _logs.result()
if (logs0.nonEmpty) previousLogs ++= logs0
if (!logs0.isEmpty) previousLogs ++= logs0
_logs = previousLogs
zPure.success(a)
fold0.success(a)
}
)
} else {
ZPure.Fold(
zPure.value,
ZPure.set(state) *> zPure.failure(_: Any),
zPure.success
fold0.value,
ZPure.set(state) *> fold0.failure(_: Any),
fold0.success
)
}

stack.push(fold)
curZPure = zPure.value
curZPure = fold0.value

case log0: Log[_, _] =>
val zPure = log0.asInstanceOf[Log[Any, Any]]
_logs addOne zPure.log
case log0: Log[Any, Any] =>
_logs addOne log0.log
val nextInstr = stack.pop()
a = ()
if (nextInstr eq null) curZPure = null else curZPure = nextInstr(a)
if (nextInstr eq null) curZPure = null else curZPure = nextInstr(())

case provide0: Provide[_, _, _, _, _, _] =>
val zPure = provide0.asInstanceOf[Provide[Any, Any, Any, Any, Any, Any]]
case provide0: Provide[Any, Any, Any, Any, Any, Any] =>
val previousEnv = _environment
_environment = zPure.r
curZPure = zPure.continue.foldM(
_environment = provide0.r
curZPure = provide0.continue.foldM(
e => { _environment = previousEnv; ZPure.fail(e) },
a => { _environment = previousEnv; ZPure.succeed(a) }
)

case environment0: Environment[_, _, _, _, _, _] =>
val zPure = environment0.asInstanceOf[Environment[Any, Any, Any, Any, Any, Any]]
a = zPure.access(_environment)
case environment0: Environment[Any, Any, Any, Any, Any, Any] =>
a = environment0.access(_environment)
val nextInstr = stack.pop()
if (nextInstr eq null) curZPure = null else curZPure = nextInstr(a)

case modify0: Modify[_, _, _] =>
val zPure = modify0.asInstanceOf[Modify[Any, Any, Any]]
val updated = zPure.run0(s0)
case modify0: Modify[Any, Any, Any] =>
val updated = modify0.run0(s0)
a = updated._1
s0 = updated._2
val nextInstr = stack.pop()
if (nextInstr eq null) curZPure = null else curZPure = nextInstr(a)

case flag0: Flag[_, _, _, _, _, _] =>
val zPure = flag0.asInstanceOf[Flag[Any, Any, Any, Any, Any, Any]]
zPure.flag match {
case FlagType.ClearLogOnError =>
val oldValue = _clearLogOnError
_clearLogOnError = zPure.value
val resetFn = (a: Any) => { _clearLogOnError = oldValue; a }
curZPure = zPure.continue.bimap(resetFn, resetFn)
}
case flag0: ClearLogOnError[Any, Any, Any, Any, Any, Any] =>
val oldValue = _clearLogOnError
_clearLogOnError = flag0.value
val resetFn = (a: Any) => { _clearLogOnError = oldValue; a }
curZPure = flag0.continue.bimap(resetFn, resetFn)

case fail0: Fail[_] =>
val zPure = fail0.asInstanceOf[Fail[Any]]
case fail0: Fail[Any] =>
curZPure = null
while (curZPure eq null)
stack.pop() match {
case null =>
throw Runner.Err(zPure.error)
case value: Fold[_, _, _, _, _, _, _, _, _] =>
val cont = value.failure.asInstanceOf[Continuation]
curZPure = cont(zPure.error)
case _ =>
case null =>
// No error handlers in the stack, exit evaluation with the error
return Left(fail0.error.asInstanceOf[E])
case value: Fold[Any, Any, Any, Any, Any, Any, Any, Any, Any] =>
curZPure = value.failure(fail0.error)
case _ =>
()
}
}

(s0.asInstanceOf[S2], a.asInstanceOf[A])
Right((s0.asInstanceOf[S2], a.asInstanceOf[A]))
}
}
}

0 comments on commit 90ea08c

Please sign in to comment.