Skip to content

Commit

Permalink
dev: Add better compile-time diagnostics
Browse files Browse the repository at this point in the history
  • Loading branch information
Iltotore committed Jun 2, 2024
1 parent 9c114fe commit c609578
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 6 deletions.
105 changes: 105 additions & 0 deletions main/src/io/github/iltotore/iron/macros/ReflectUtil.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package io.github.iltotore.iron.macros

import io.github.iltotore.iron.compileTime.NumConstant

import scala.quoted.*

final class Foo

transparent inline def foo: Foo = new Foo

transparent inline def reflectUtil(using inline q: Quotes): ReflectUtil = new ReflectUtil

class ReflectUtil(using val _quotes: Quotes):

import quotes.reflect.*

enum DecodingFailure:
case Unknown
case NotInlined(term: Term)
case HasBindings(bindings: List[Definition])
case HasStatements(statements: List[Statement])
case ApplyNotInlined()

case OrNotInlined(left: Either[DecodingFailure, Boolean], right: Either[DecodingFailure, Boolean])
case AndNotInlined(left: Either[DecodingFailure, Boolean], right: Either[DecodingFailure, Boolean])
case StringPartsNotInlined(parts: List[Either[DecodingFailure, String]])

trait ExprDecoder[T]:

def decode(expr: Expr[T]): Either[DecodingFailure, T]

object ExprDecoder:

given [T](using fromExpr: FromExpr[T]): ExprDecoder[T] with

override def decode(expr: Expr[T]): Either[DecodingFailure, T] =
fromExpr.unapply(expr).toRight(DecodingFailure.Unknown)

private class PrimitiveExprDecoder[T <: NumConstant | Byte | Short | Boolean | String] extends ExprDecoder[T]:

def decodeTerm(tree: Term): Either[DecodingFailure, T] =
tree match
case Block(stats, e) => if stats.isEmpty then decodeTerm(e) else Left(DecodingFailure.HasStatements(stats))
case Inlined(_, bindings, e) => if bindings.isEmpty then decodeTerm(e) else Left(DecodingFailure.HasBindings(bindings))
case Typed(e, _) => decodeTerm(e)
case Apply(left, operands) => ???
case ref: Ref => Left(DecodingFailure.NotInlined(ref))
case _ =>
tree.tpe.widenTermRefByName match
case ConstantType(c) => Right(c.value.asInstanceOf[T])
case _ => Left(DecodingFailure.Unknown)

override def decode(expr: Expr[T]): Either[DecodingFailure, T] =
decodeTerm(expr.asTerm)

given [T <: NumConstant]: ExprDecoder[T] = new PrimitiveExprDecoder[T]

/**
* A ExprDecoder[Boolean] that can extract value from partially inlined || and
* && operations.
*
* {{{
* inline val x = true
* val y: Boolean = ???
*
* x || y //inlined to `true`
* y || x //inlined to `true`
*
* inline val a = false
* val b: Boolean = ???
*
* a && b //inlined to `false`
* b && a //inlined to `false`
* }}}
*/
given ExprDecoder[Boolean] = new PrimitiveExprDecoder[Boolean]:

override def decodeTerm(tree: Term): Either[DecodingFailure, Boolean] =
tree match
case Apply(Select(left, "||"), List(right)) if left.tpe <:< TypeRepr.of[Boolean] && right.tpe <:< TypeRepr.of[Boolean] => // OR
(decodeTerm(left), decodeTerm(right)) match
case (Right(true), _) => Right(true)
case (_, Right(true)) => Right(true)
case (Right(leftValue), Right(rightValue)) => Right(leftValue || rightValue)
case (leftResult, rightResult) => Left(DecodingFailure.OrNotInlined(leftResult, rightResult))

case Apply(Select(left, "&&"), List(right)) if left.tpe <:< TypeRepr.of[Boolean] && right.tpe <:< TypeRepr.of[Boolean] => // AND
(decodeTerm(left), decodeTerm(right)) match
case (Right(false), _) => Right(false)
case (_, Right(false)) => Right(false)
case (Right(leftValue), Right(rightValue)) => Right(leftValue && rightValue)
case (leftResult, rightResult) => Left(DecodingFailure.AndNotInlined(leftResult, rightResult))

case _ => super.decodeTerm(tree)

given ExprDecoder[String] = new PrimitiveExprDecoder[String]:

override def decodeTerm(tree: Term): Either[DecodingFailure, String] =
tree match
case Apply(Select(left, "+"), List(right)) if left.tpe <:< TypeRepr.of[String] && right.tpe <:< TypeRepr.of[String] =>
(decodeTerm(left), decodeTerm(right)) match
case (Right(leftValue), Right(rightValue)) => Right(leftValue + rightValue)
case (leftResult, rightResult) => Left(DecodingFailure.StringPartsNotInlined(List(leftResult, rightResult)))

case _ => super.decodeTerm(tree)
19 changes: 13 additions & 6 deletions main/src/io/github/iltotore/iron/macros/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -108,21 +108,28 @@ private def assertConditionImpl[A: Type](input: Expr[A], cond: Expr[Boolean], me

import quotes.reflect.*

val reflectUtil = new ReflectUtil

val inputType = TypeRepr.of[A]

val messageValue = message.value.getOrElse("<Unknown message>")
val condValue = cond.value
.getOrElse(
compileTimeError(
val messageDecoder = summon[reflectUtil.ExprDecoder[String]]
val condDecoder = summon[reflectUtil.ExprDecoder[Boolean]]

val messageValue = messageDecoder.decode(message).getOrElse("<Unknown message>")
val condValue = condDecoder.decode(cond)
.fold(
err => compileTimeError(
s"""Cannot refine value at compile-time because the predicate cannot be evaluated.
|This is likely because the condition or the input value isn't fully inlined.
|
|To test a constraint at runtime, use one of the `refine...` extension methods.
|
|${MAGENTA}Inlined input$RESET: ${input.show}
|${MAGENTA}Inlined condition$RESET: ${cond.show}
|${MAGENTA}Message$RESET: $messageValue""".stripMargin
)
|${MAGENTA}Message$RESET: $messageValue
|$err""".stripMargin
),
identity
)

if !condValue then
Expand Down

0 comments on commit c609578

Please sign in to comment.