Skip to content

Commit

Permalink
Merge pull request #3 from a-khakimov/RecursiveCallLinter
Browse files Browse the repository at this point in the history
Add NoRecursion Rule
  • Loading branch information
a-khakimov authored Dec 8, 2024
2 parents 388fd8b + 2d67db7 commit e9c9830
Show file tree
Hide file tree
Showing 12 changed files with 290 additions and 17 deletions.
11 changes: 3 additions & 8 deletions build.sbt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
lazy val V = _root_.scalafix.sbt.BuildInfo

lazy val rulesCrossVersions = Seq(V.scala213, V.scala212)
lazy val rulesCrossVersions = Seq(V.scala213)
lazy val scala3Version = "3.3.0"

ThisBuild / sonatypeCredentialHost := "s01.oss.sonatype.org"
Expand All @@ -21,7 +21,7 @@ inThisBuild(
),
semanticdbEnabled := true,
semanticdbVersion := scalafixSemanticdb.revision,
crossScalaVersions := List(V.scala213, V.scala212)
crossScalaVersions := List(V.scala213)
)
)

Expand Down Expand Up @@ -86,7 +86,7 @@ lazy val tests = projectMatrix
rulesCrossVersions.map(VirtualAxis.scalaABIVersion) :+ VirtualAxis.jvm: _*
)
.jvmPlatform(
scalaVersions = Seq(V.scala212),
scalaVersions = Seq(V.scala213),
axisValues = Seq(TargetAxis(scala3Version)),
settings = Seq()
)
Expand All @@ -95,10 +95,5 @@ lazy val tests = projectMatrix
axisValues = Seq(TargetAxis(V.scala213)),
settings = Seq()
)
.jvmPlatform(
scalaVersions = Seq(V.scala212),
axisValues = Seq(TargetAxis(V.scala212)),
settings = Seq()
)
.dependsOn(rules)
.enablePlugins(ScalafixTestkitPlugin)
21 changes: 20 additions & 1 deletion docs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ error: [NoGeneralException] Exception is not allowed

This rule requires specifying arguments names when calling a function. It makes the code more readable.


```scala
function("Word", 42, 73.1)
/*
Expand Down Expand Up @@ -100,6 +99,26 @@ map(1)

---

### NoRecursion

This rule prohibits the use of recursive calls

```scala
object Main {
def foo(): Unit = bar()
def bar(): Unit = baz()
def baz(): Unit = foo()
}

/*
[error] main.scala:4:21: error: [NoRecursion.DetectedRecursion] Recursion detected: Main.bar() -> Main.baz() -> Main.foo()
[error] def foo(): Unit = bar()
[error] ^^^^^
*/
```

---

## Rewrite rules

### MakeArgsNamed (experimental!)
Expand Down

This file was deleted.

15 changes: 15 additions & 0 deletions input/src/main/scala/fix/RecursiveCall0.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package fix

/*
rule = NoRecursion
*/

class RecursiveCall0 {

def foo(): Unit = baz()
def bar(): Unit = foo()
def baz(): Unit = bar() /* assert: NoRecursion.DetectedRecursion
^^^^^
Recursion detected: RecursiveCall0.bar() -> RecursiveCall0.foo() -> RecursiveCall0.baz()
*/
}
22 changes: 22 additions & 0 deletions input/src/main/scala/fix/RecursiveCall1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package fix

/*
rule = NoRecursion
*/

class RecursiveCall1 {

trait Ooo {
def ooo(s: String)(implicit i: Int): Unit
}

class OooImpl extends Ooo {
override def ooo(s: String)(implicit i: Int): Unit = qqq(s)

def qqq(a: String)(implicit i: Int): Unit = www(a) /* assert: NoRecursion.DetectedRecursion
^^^^^^
Recursion detected: OooImpl.www(String,Int) -> OooImpl.qqq(String,Int)
*/
def www(b: String)(implicit i: Int): Unit = qqq(b)
}
}
13 changes: 13 additions & 0 deletions input/src/main/scala/fix/RecursiveCall2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package fix

/*
rule = NoRecursion
*/

class RecursiveCall2 {

def mmm(): Unit = mmm() /* assert: NoRecursion.DetectedRecursion
^^^^^
Recursion detected: RecursiveCall2.mmm()
*/
}
19 changes: 19 additions & 0 deletions input/src/main/scala/fix/RecursiveCall3.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package fix

/*
rule = NoRecursion
*/

class RecursiveCall3 {

trait Ggg {
def ggg: Unit
}

class GggImpl extends Ggg {
override def ggg: Unit = eee()

def eee(): Unit = println("Hello, eee")
def foo(): Unit = println("Hello, foo")
}
}
26 changes: 26 additions & 0 deletions input/src/main/scala/fix/RecursiveCall4.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package fix

import scala.annotation.tailrec

/*
rule = NoRecursion
*/

class RecursiveCall4 {

def calc = {
def factorial(n: Int): Int = {
@tailrec def factorialAcc(acc: Int, n: Int): Int = { /* assert: NoRecursion.DetectedRecursion
^
Recursion detected: RecursiveCall4.calc().factorialAcc(Int,Int)
*/
if (n <= 1) acc
else factorialAcc(n * acc, n - 1)
}

factorialAcc(1, n)
}

factorial(1)
}
}
20 changes: 20 additions & 0 deletions input/src/main/scala/fix/RecursiveCall5.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package fix

/*
rule = NoRecursion
*/

class RecursiveCall5 {

trait Ooo {
def ooo(): Unit
}

trait Vvv {
def ooo(): Unit
}

class OooImpl(vvv: Vvv) extends Ooo {
override def ooo(): Unit = vvv.ooo()
}
}
19 changes: 19 additions & 0 deletions input/src/main/scala/fix/RecursiveCall6.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package fix

/*
rule = NoRecursion
*/

class RecursiveCall6 {

trait Handler[T] {
def handle(): Unit = println("foo")
}

class FooHandler extends Handler[FooHandler] {
override def handle(): Unit = {
super.handle()
println("bar")
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ fix.NoHead
fix.NoMapApply
fix.NoOptionGet
fix.NoUnnamedArgs
fix.NoRecursion
132 changes: 132 additions & 0 deletions rules/src/main/scala/fix/NoRecursion.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
package fix

import scalafix.v1._

import scala.collection.mutable
import scala.meta._

class NoRecursion extends SemanticRule("NoRecursion") {

override def description: String = "Find recursion"

trait Pos {
def pos: Position
}

case class Meth(name: String)

object Meth {
def make(name: String, position: Position): Meth with Pos = new Meth(name) with Pos {
override def pos: Position = position
override def toString: String = name
}
}

override def fix(implicit doc: SemanticDocument): Patch = {

// fixme: Use immutable Map
val callGraph = mutable.Map.empty[Meth with Pos, mutable.Set[Meth with Pos]]

doc.tree.collect {
case _ @ Defn.Object(_, name, templ) =>
collectMethods(name.value, templ, callGraph)
case _ @ Defn.Class.After_4_6_0(_, name, _, _, templ) =>
collectMethods(name.value, templ, callGraph)
}

val cycles = findCycle(callGraph.toMap.map { case (k, v) => k -> v.toSet })

cycles.map { cycle =>
Patch.lint(
Diagnostic(
"DetectedRecursion", s"Recursion detected: ${cycle.map(_.name).mkString(" -> ")}",
cycle.headOption.map(_.pos).getOrElse(doc.tree.pos)
)
)
}.asPatch
}

private def collectMethods(
qualifier: String,
templ: Template,
callGraph: mutable.Map[Meth with Pos, mutable.Set[Meth with Pos]]
)(implicit doc: SemanticDocument): Unit = {
templ.stats.foreach {
case defn @ Defn.Def.After_4_6_0(_, name, _, _, body) =>
val params = getParameterTypes(defn.symbol)
val fullName = s"$qualifier.${name.value}(${params.mkString(",")})"
val calls = collectFunctionCalls(body).map(
c => if (c.name.contains(".")) c else Meth.make(s"$qualifier.$c", defn.pos) // fixme: This is 🩼
)
callGraph.getOrElseUpdate(Meth.make(fullName, defn.pos), mutable.Set.empty) ++= calls
collectMethodsWithNestedFunctions(fullName, body, callGraph)
case _ =>
}
}

private def getParameterTypes(symbol: Symbol)(implicit doc: SemanticDocument): List[String] = {
symbol.info.map(_.signature) match {
case Some(MethodSignature(_, parameters, _)) => parameters.flatten.map(_.signature.toString())
case _ => Nil
}
}

private def collectFunctionCalls(body: Term)(implicit doc: SemanticDocument): Set[Meth with Pos] = {
body.collect {
case defn @ Term.Apply.After_4_6_0(Term.Select(Term.This(_), Term.Name(methodName)), _) =>
val className = defn.symbol.owner.info.map(_.displayName + ".").getOrElse("")
val params = getParameterTypes(defn.symbol)
val method = Meth.make(s"${className}$methodName(${params.mkString(",")})", defn.pos)
Some(method)
case defn @ Term.Apply.After_4_6_0(Term.Name(methodName), _) =>
val className = defn.symbol.owner.info.map(_.displayName + ".").getOrElse("")
val params = getParameterTypes(defn.symbol)
val method = Meth.make(s"${className}$methodName(${params.mkString(",")})", defn.pos)
Some(method)
case defn @ Term.Apply.After_4_6_0(Term.Select(receiver, Term.Name(methodName)), _) if !receiver.is[Term.This] && !receiver.is[Term.Name] =>
val className = defn.symbol.owner.info.map(_.displayName + ".").getOrElse("")
val params = getParameterTypes(defn.symbol)
val method = Meth.make(s"${className}$methodName(${params.mkString(",")})", defn.pos)
Some(method)
// Ignore super.methodName
case _ @ Term.Apply.After_4_6_0(Term.Select(Term.Super(_, _), Term.Name(_)), _) =>
None
}.toSet
}.flatten

private def collectMethodsWithNestedFunctions(
qualifier: String,
body: Term,
callGraph: mutable.Map[Meth with Pos, mutable.Set[Meth with Pos]]
)(implicit doc: SemanticDocument): Unit = {
body.collect {
case defn @ Defn.Def.After_4_6_0(_, name, _, _, innerBody) =>
val params = getParameterTypes(defn.symbol)
val fullName = s"$qualifier.${name.value}(${params.mkString(",")})"
val calls = collectFunctionCalls(innerBody).map {
case call if call.name.contains(".") => call
case call => Meth.make(s"$qualifier.$call", defn.pos) // fixme: This is 🩼
}
callGraph.getOrElseUpdate(Meth.make(fullName, defn.pos), mutable.Set.empty) ++= calls
collectMethodsWithNestedFunctions(fullName, innerBody, callGraph)
}
}

def findCycle(graph: Map[Meth with Pos, Set[Meth with Pos]]): Option[List[Meth with Pos]] = {
def dfs(
node: Meth with Pos,
visited: Set[Meth with Pos],
inStack: Set[Meth with Pos],
path: List[Meth with Pos]
): Option[List[Meth with Pos]] = {
if (inStack.contains(node)) Some((node :: path.takeWhile(_ != node)).reverse)
else if (visited.contains(node)) None
else graph.getOrElse(node, Set.empty).toList.foldLeft(Option.empty[List[Meth with Pos]]) {
case (cycle @ Some(_), _) => cycle
case (None, neighbor) => dfs(neighbor, visited + node, inStack + node, node :: path)
}
}

graph.keys.flatMap(node => dfs(node, Set.empty, Set.empty, Nil)).headOption
}
}

0 comments on commit e9c9830

Please sign in to comment.