Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add OpIntercept to allow custom types and operations #149

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/main/scala/singleton/ops/OpIntercept.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package singleton.ops
import scala.reflect.macros.whitebox
import impl._

import scala.annotation.implicitNotFound

trait OpIntercept[Op <: HasOut]
object OpIntercept {
@implicitNotFound("Failed to cache with result ${Out}")
trait CacheResult[Out]
object CacheResult {
implicit def call[Out] : CacheResult[Out] = macro Macro.materializeCacheResult[Out]
final class Macro(val c: whitebox.Context) extends GeneralMacros {
def materializeCacheResult[
Out: c.WeakTypeTag
]: c.Tree = cacheOpInterceptResult[Out]
}

}
}
115 changes: 86 additions & 29 deletions src/main/scala/singleton/ops/impl/GeneralMacros.scala
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
package singleton.ops.impl
import singleton.twoface.impl.TwoFaceAny

import scala.reflect.macros.whitebox
import scala.reflect.macros.{TypecheckException, whitebox}

private object MacroCache {
import scala.collection.mutable
val cache = mutable.Map.empty[Any, Any]
def get(key : Any) : Option[Any] = cache.get(key)
def add[V <: Any](key : Any, value : V) : V = {cache += (key -> value); value}
private var opInterceptValue : Option[Any] = None
def setOpInterceptValue(value : Any) : Unit = opInterceptValue = Some(value)
def getOpInterceptValue : Any = opInterceptValue
def clearOpInterceptValue() : Unit = opInterceptValue = None
}
trait GeneralMacros {
val c: whitebox.Context
Expand Down Expand Up @@ -261,7 +265,7 @@ trait GeneralMacros {

def unapply(arg: CalcType): Option[Primitive] = Some(arg.primitive)
}
case class CalcUnknown(tpe : Type, treeOption : Option[Tree]) extends Calc {
case class CalcUnknown(tpe : Type, treeOption : Option[Tree], opIntercept : Boolean) extends Calc {
override val primitive: Primitive = Primitive.Unknown(tpe, "Unknown")
}
object NonLiteralCalc {
Expand Down Expand Up @@ -328,6 +332,19 @@ trait GeneralMacros {
VerboseTraversal(s"${GREEN}${BOLD}caching${RESET} $k -> $value")
value
}
def setOpInterceptCalc(calc : Calc) : Unit = MacroCache.setOpInterceptValue(Left(calc))
def setOpInterceptError(msg : String) : Unit = MacroCache.setOpInterceptValue(Right(msg))
def clearOpInterceptCalc() : Unit = MacroCache.clearOpInterceptValue()
def getOpInterceptCalc : Option[Either[Calc, String]] = {
MacroCache.getOpInterceptValue.asInstanceOf[Option[Either[Calc, String]]] match {
case Some(Left(v)) => Some(Left(v match {
case lit : CalcLit => CalcLit(lit.value) //reconstruct internal literal tree
case nlit : CalcNLit => CalcNLit(nlit.primitive, deepCopyTree(nlit.tree))
case c => c
}))
case v => v
}
}
}
////////////////////////////////////////////////////////////////////

Expand Down Expand Up @@ -456,7 +473,7 @@ trait GeneralMacros {
def unapply(tp: Type): Option[Calc] = {
tp match {
case TypeRef(_, sym, ft :: tp :: _) if sym == opMacroSym && ft.typeSymbol == funcTypes.GetType =>
Some(CalcUnknown(tp, None))
Some(CalcUnknown(tp, None, opIntercept = false))
case TypeRef(_, sym, args) if sym == opMacroSym =>
VerboseTraversal(s"@@OpCalc@@\nTP: $tp\nRAW: ${showRaw(tp)}")
val funcType = args.head.typeSymbol.asType
Expand All @@ -473,7 +490,7 @@ trait GeneralMacros {
case (funcTypes.ImplicitFound, _) =>
setUncachingReason(1)
aValue match {
case CalcUnknown(t, _) => try {
case CalcUnknown(t, _, false) => try {
c.typecheck(q"implicitly[$t]")
Some(CalcLit(true))
} catch {
Expand All @@ -484,7 +501,7 @@ trait GeneralMacros {
}
case (funcTypes.EnumCount, _) =>
aValue match {
case CalcUnknown(t, _) => Some(CalcLit(t.typeSymbol.asClass.knownDirectSubclasses.size))
case CalcUnknown(t, _, false) => Some(CalcLit(t.typeSymbol.asClass.knownDirectSubclasses.size))
case _ => Some(CalcLit(0))
}
case (funcTypes.IsNat, _) =>
Expand Down Expand Up @@ -549,9 +566,10 @@ trait GeneralMacros {
}

case _ => //regular cases
opCalc(funcType, aValue, bValue, cValue) match {
opCalc(Some(tp), funcType, aValue, bValue, cValue) match {
case (res : CalcVal) => Some(res)
case u @ CalcUnknown(_,Some(_)) => Some(u) //Accept unknown values with a tree
case u @ CalcUnknown(_,Some(_), _) => Some(u) //Accept unknown values with a tree
case oi @ CalcUnknown(_,_, true) => Some(oi) //Accept unknown op interception
case _ => None
}
}
Expand All @@ -575,7 +593,7 @@ trait GeneralMacros {
case Some(t : CalcUnknown) => t
case _ =>
VerboseTraversal(s"@@Unknown@@\nTP: $tp\nRAW: ${showRaw(tp)}")
CalcUnknown(tp, None)
CalcUnknown(tp, None, opIntercept = false)
}
}

Expand Down Expand Up @@ -654,10 +672,11 @@ trait GeneralMacros {
}
////////////////////////////////////////////////////////////////////////

def abort(msg: String, annotatedSym : Option[TypeSymbol] = defaultAnnotatedSym): Nothing = {
def abort(msg: String, annotatedSym : Option[TypeSymbol] = defaultAnnotatedSym, position : Position = c.enclosingPosition): Nothing = {
VerboseTraversal(s"!!!!!!aborted with: $msg at $annotatedSym, $defaultAnnotatedSym")
if (annotatedSym.isDefined) setAnnotation(msg, annotatedSym.get)
c.abort(c.enclosingPosition, msg)
CalcCache.setOpInterceptError(msg) //propagating the error in case this is an inner implicit call for OpIntercept
c.abort(position, msg)
}

def buildWarningMsgLoc : String = s"${c.enclosingPosition.source.path}:${c.enclosingPosition.line}:${c.enclosingPosition.column}"
Expand Down Expand Up @@ -734,11 +753,11 @@ trait GeneralMacros {
case None =>
q"""
new $opTpe {
type OutWide = Option[$outTpe]
type Out = Option[$outTpe]
final val value: Option[$outTpe] = None
type OutWide = $outTpe
type Out = $outTpe
final lazy val value: $outTpe = throw new IllegalArgumentException("This operation does not produce a value.")
final val isLiteral = false
final val valueWide: Option[$outTpe] = None
final lazy val valueWide: $outTpe = throw new IllegalArgumentException("This operation does not produce a value.")
}
"""
}
Expand Down Expand Up @@ -771,11 +790,11 @@ trait GeneralMacros {
}

opTree match {
case q"""{
$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends ..$parents { $self => ..$opClsBlk }
$expr(...$exprss)
}""" => getOut(opClsBlk)
case _ => extractionFailed(opTree)
case q"""{
$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends ..$parents { $self => ..$opClsBlk }
$expr(...$exprss)
}""" => getOut(opClsBlk)
case _ => extractionFailed(opTree)
}
}

Expand Down Expand Up @@ -881,27 +900,64 @@ trait GeneralMacros {
val (typedTree, tpe) = GetArgTree(argIdx, lhs)
VerboseTraversal(s"@@extractFromArg@@\nTP: $tpe\nRAW: ${showRaw(tpe)}\nTree: $typedTree")
TypeCalc(tpe) match {
case _ : CalcUnknown => CalcUnknown(tpe, Some(c.untypecheck(typedTree)))
case _ : CalcUnknown => CalcUnknown(tpe, Some(c.untypecheck(typedTree)), opIntercept = false)
case t : CalcNLit => CalcNLit(t, typedTree)
case t => t
}
}
///////////////////////////////////////////////////////////////////////////////////////////

///////////////////////////////////////////////////////////////////////////////////////////
// OpInterept Result Caching
///////////////////////////////////////////////////////////////////////////////////////////
def cacheOpInterceptResult[Out](implicit ev0: c.WeakTypeTag[Out]) : Tree = {
val outTpe = weakTypeOf[Out]
val outCalc = TypeCalc(outTpe)
CalcCache.setOpInterceptCalc(outCalc)
q"new _root_.singleton.ops.OpIntercept.CacheResult[$outTpe]{}"
}
///////////////////////////////////////////////////////////////////////////////////////////

///////////////////////////////////////////////////////////////////////////////////////////
// Three operands (Generic)
///////////////////////////////////////////////////////////////////////////////////////////
def materializeOpGen[F](implicit ev0: c.WeakTypeTag[F]): MaterializeOpAuxGen =
new MaterializeOpAuxGen(weakTypeOf[F])

def opCalc(funcType : TypeSymbol, aCalc : => Calc, bCalc : => Calc, cCalc : => Calc) : Calc = {
def opCalc(opTpe : Option[Type], funcType : TypeSymbol, aCalc : => Calc, bCalc : => Calc, cCalc : => Calc) : Calc = {
lazy val a = aCalc
lazy val b = bCalc
lazy val cArg = cCalc
def unsupported() : Calc = {
(a, b) match {
case (aArg : CalcVal, bArg : CalcVal) => abort(s"Unsupported $funcType[$a, $b, $cArg]")
case _ => CalcUnknown(funcType.toType, None)
val cachedTpe = opTpe.get match {
case TypeRef(pre, sym, args) => c.internal.typeRef(pre, sym, List(funcType.toType, a.tpe, b.tpe, cArg.tpe))
}
//calling OpIntercept for the operation should cache the expected result if executed correctly
CalcCache.clearOpInterceptCalc()
val implicitlyTree = q"implicitly[_root_.singleton.ops.OpIntercept[$cachedTpe]]"
try {
c.typecheck(implicitlyTree, silent = false)
val cachedCalc = CalcCache.getOpInterceptCalc match {
case Some(calc) => calc
case None => abort("Missing a result cache for OpIntercept. Make sure you set `OpIntercept.CacheResult`")
}
CalcCache.clearOpInterceptCalc()
cachedCalc match {
case Left(t : CalcUnknown) =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is totally nitpicking, but IIUC, the convention for Either is that the Left is the "error" condition, isn't it?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right. I've rarely used it in this fashion, so forgot. I'll fix.

t.copy(opIntercept = true) //the unknown result must be marked properly so we allow it later
case Left(t) => t
case Right(msg) => abort(msg)
}
} catch {
case TypecheckException(pos, msg) =>
CalcCache.getOpInterceptCalc match {
case Some(Right(msg)) => abort(msg)
case _ =>
(a, b) match {
case (_ : CalcVal, _ : CalcVal) => abort(s"Unsupported operation $cachedTpe")
case _ => CalcUnknown(funcType.toType, None, opIntercept = false)
}
}
}
}

Expand Down Expand Up @@ -1055,7 +1111,7 @@ trait GeneralMacros {
}
//directly using the java lib `require` resulted in compiler crash, so we use wrapped require instead
case CalcNLit(Primitive.String, msg, _) => cArg match {
case CalcUnknown(t, _) if t.typeSymbol == symbolOf[Warn] =>
case CalcUnknown(t, _, false) if t.typeSymbol == symbolOf[Warn] =>
CalcNLit(Primitive.Boolean, q"""{println(${buildWarningMsg(msg)}); false}""")
case _ =>
CalcNLit(Primitive.Boolean, q"{_root_.singleton.ops.impl._require(false, $msg); false}")
Expand All @@ -1065,7 +1121,7 @@ trait GeneralMacros {
case CalcNLit(Primitive.Boolean, cond, _) => b match {
//directly using the java lib `require` resulted in compiler crash, so we use wrapped require instead
case CalcVal(msg : String, msgt) => cArg match {
case CalcUnknown(t, _) if t == symbolOf[Warn] =>
case CalcUnknown(t, _, false) if t == symbolOf[Warn] =>
CalcNLit(Primitive.Boolean,
q"""{
if ($cond) true
Expand Down Expand Up @@ -1366,7 +1422,7 @@ trait GeneralMacros {
case funcTypes.PrefixMatch => PrefixMatch
case funcTypes.ReplaceFirstMatch => ReplaceFirstMatch
case funcTypes.ReplaceAllMatches => ReplaceAllMatches
case _ => abort(s"Unsupported $funcType[$a, $b, $cArg]")
case _ => unsupported()
}
}

Expand All @@ -1381,6 +1437,7 @@ trait GeneralMacros {
else genOpTreeNat(opTpe, t)
case (_, CalcLit(_, t)) => genOpTreeLit(opTpe, t)
case (funcTypes.AcceptNonLiteral | funcTypes.GetArg, t : CalcNLit) => genOpTreeNLit(opTpe, t)
case (_, t @ CalcUnknown(_,_,true)) => genOpTreeUnknown(opTpe, t)
soronpo marked this conversation as resolved.
Show resolved Hide resolved
case (funcTypes.GetArg, t : CalcUnknown) => genOpTreeUnknown(opTpe, t)
case (_, t: CalcNLit) =>
abort("Calculation has returned a non-literal type/value.\nTo accept non-literal values, use `AcceptNonLiteral[T]`.")
Expand Down Expand Up @@ -1500,7 +1557,7 @@ trait GeneralMacros {
}
}

val reqCalc = opCalc(funcTypes.Require, condCalc, msgCalc, CalcUnknown(typeOf[NoSym], None))
val reqCalc = opCalc(None, funcTypes.Require, condCalc, msgCalc, CalcUnknown(typeOf[NoSym], None, opIntercept = false))

q"""
(new $chkSym[$condTpe, $msgTpe, $chkArgTpe]($outTree.asInstanceOf[$outTpe]))
Expand Down Expand Up @@ -1566,7 +1623,7 @@ trait GeneralMacros {
}
}

val reqCalc = opCalc(funcTypes.Require, condCalc, msgCalc, CalcUnknown(typeOf[NoSym], None))
val reqCalc = opCalc(None, funcTypes.Require, condCalc, msgCalc, CalcUnknown(typeOf[NoSym], None, opIntercept = false))

q"""
(new $chkSym[$condTpe, $msgTpe, $chkArgTpe, $paramFaceTpe, $paramTpe]($outTree.asInstanceOf[$outTpe]))
Expand Down
8 changes: 6 additions & 2 deletions src/main/scala/singleton/ops/impl/Op.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@ trait HasOut extends Any with Serializable {
type Out
}

trait Op extends HasOut {
trait HasOutValue extends HasOut {
val value : Out
}

trait Op extends HasOutValue {
type OutWide
type Out
type OutNat <: Nat
Expand All @@ -29,7 +33,7 @@ protected[singleton] object OpGen {
implicit def getValue[O <: Op, Out](o : Aux[O, Out]) : Out = o.value
}

trait OpCast[T, O <: Op] extends HasOut {type Out <: T; val value : Out}
trait OpCast[T, O <: Op] extends HasOutValue {type Out <: T}


@scala.annotation.implicitNotFound(msg = "Unable to prove type argument is a Nat.")
Expand Down
81 changes: 81 additions & 0 deletions src/test/scala/singleton/ops/OpInterceptSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package singleton.ops

import org.scalacheck.Properties
import shapeless.test.illTyped
import singleton.TestUtils._

class OpInterceptSpec extends Properties("OpInterceptSpec") {

trait Vec[A0, A1]

implicit def `Vec+`[VL0, VL1, VR0, VR1, VO0, VO1](
implicit
opL : OpAuxGen[VL0 + VR0, VO0],
opR : OpAuxGen[VL1 + VR1, VO1],
result : OpIntercept.CacheResult[Vec[VO0, VO1]]
) : OpIntercept[Vec[VL0, VL1] + Vec[VR0, VR1]] = ???

implicit def `Vec==`[VL0, VL1, VR0, VR1, EqOut](
implicit
op : OpAuxGen[(VL0 == VR0) && (VL1 == VR1), EqOut],
result : OpIntercept.CacheResult[EqOut]
) : OpIntercept[Vec[VL0, VL1] == Vec[VR0, VR1]] = ???


property("Custom Vec Equality OK") = wellTyped {
val eq1 = shapeless.the[Vec[W.`1`.T, W.`2`.T] == Vec[W.`1`.T, W.`2`.T]]
val eq2 = shapeless.the[Vec[W.`1`.T, W.`2`.T] == Vec[W.`1`.T, W.`1`.T]]
implicitly[eq1.Out =:= W.`true`.T]
implicitly[eq2.Out =:= W.`false`.T]
}

property("Custom Vec Addition OK") = wellTyped {
val add2 = shapeless.the[Vec[W.`1`.T, W.`2`.T] + Vec[W.`3`.T, W.`8`.T]]
val add3 = shapeless.the[Vec[W.`1`.T, W.`2`.T] + Vec[W.`3`.T, W.`8`.T] + Vec[W.`20`.T, W.`20`.T]]
implicitly[add2.Out =:= Vec[W.`4`.T, W.`10`.T]]
implicitly[add3.Out =:= Vec[W.`24`.T, W.`30`.T]]
val add23 = shapeless.the[add2.Out + add3.Out]
implicitly[add23.Out =:= Vec[W.`28`.T, W.`40`.T]]
}

trait FibId
type Fib[P] = impl.OpMacro[FibId, P, W.`0`.T, W.`0`.T]
implicit def doFib[P, Out](
implicit
op : OpAuxGen[ITE[P == W.`0`.T, W.`0`.T, ITE[P == W.`1`.T, W.`1`.T, Fib[P - W.`1`.T] + Fib[P - W.`2`.T]]], Out],
result : OpIntercept.CacheResult[Out]
) : OpIntercept[Fib[P]] = ???


property("Custom Fibonacci Op OK") = wellTyped {
val fib4 = shapeless.the[Fib[W.`4`.T]]
implicitly[fib4.Out =:= W.`3`.T]
val fib10 = shapeless.the[Fib[W.`10`.T]]
implicitly[fib10.Out =:= W.`55`.T]
}


trait FooOpId
type FooOp[C, M] = impl.OpMacro[FooOpId, C, M, W.`0`.T]
implicit def FooOp[C, M](
implicit
r : RequireMsg[C, M],
result : OpIntercept.CacheResult[W.`true`.T]
) : OpIntercept[FooOp[C, M]] = ???

property("Error Message Propagation") = wellTyped {
illTyped("""shapeless.the[FooOp[W.`false`.T, W.`"this is a test"`.T]]""", "this is a test")
}

trait BarOpId
type BarOp[C, M] = impl.OpMacro[BarOpId, C, M, W.`0`.T]
implicit def BarOp[C, M](
implicit
op : C + M
) : OpIntercept[BarOp[C, M]] = ???

property("Missing Caching Error") = wellTyped {
illTyped("""shapeless.the[BarOp[W.`1`.T, W.`2`.T]]""", "Missing a result cache for OpIntercept. Make sure you set `OpIntercept.CacheResult`")
}

}