diff --git a/.scalafmt.conf b/.scalafmt.conf index cccc72a9..3abbee71 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -1,10 +1,11 @@ -version = 3.8.1 +version = 3.8.2 runner.dialect = scala3 preset = IntelliJ maxColumn = 100 align.preset = true indent.main = 4 +indent.significant = 2 newlines.source = keep rewrite.scala3.convertToNewSyntax = true diff --git a/build.sbt b/build.sbt index e53c54e9..01028236 100644 --- a/build.sbt +++ b/build.sbt @@ -1,6 +1,6 @@ name := "chen" ThisBuild / organization := "io.appthreat" -ThisBuild / version := "2.1.2" +ThisBuild / version := "2.1.3" ThisBuild / scalaVersion := "3.4.2" val cpgVersion = "1.0.0" diff --git a/codemeta.json b/codemeta.json index 38db431a..482bd4d9 100644 --- a/codemeta.json +++ b/codemeta.json @@ -7,7 +7,7 @@ "downloadUrl": "https://github.com/AppThreat/chen", "issueTracker": "https://github.com/AppThreat/chen/issues", "name": "chen", - "version": "2.1.2", + "version": "2.1.3", "description": "Code Hierarchy Exploration Net (chen) is an advanced exploration toolkit for your application source code and its dependency hierarchy.", "applicationCategory": "code-analysis", "keywords": [ diff --git a/console/src/main/scala/io/appthreat/console/BridgeBase.scala b/console/src/main/scala/io/appthreat/console/BridgeBase.scala index 0be16159..3211cef0 100644 --- a/console/src/main/scala/io/appthreat/console/BridgeBase.scala +++ b/console/src/main/scala/io/appthreat/console/BridgeBase.scala @@ -44,306 +44,306 @@ case class Config( trait BridgeBase extends InteractiveShell with ScriptExecution with PluginHandling with ServerHandling: - def jProduct: JProduct - - protected def parseConfig(args: Array[String]): Config = - val parser = new scopt.OptionParser[Config](jProduct.name): - override def errorOnUnknownArgument = false - - note("Script execution") - - opt[Path]("script") - .action((x, c) => c.copy(scriptFile = Some(x))) - .text("path to script file: will execute and exit") - - opt[String]("param") - .valueName("param1=value1") - .unbounded() - .optional() - .action { (x, c) => - x.split("=", 2) match - case Array(key, value) => c.copy(params = c.params + (key -> value)) - case _ => - throw new IllegalArgumentException(s"unable to parse param input $x") - } - .text("key/value pair for main function in script - may be passed multiple times") - - opt[Path]("import") - .valueName("script1.sc") - .unbounded() - .optional() - .action((x, c) => c.copy(additionalImports = c.additionalImports :+ x)) - .text( - "import (and run) additional script(s) on startup - may be passed multiple times" - ) - - opt[String]('d', "dep") - .valueName("com.michaelpollmeier:versionsort:1.0.7") - .unbounded() - .optional() - .action((x, c) => c.copy(dependencies = c.dependencies :+ x)) - .text( - "add artifacts (including transitive dependencies) for given maven coordinate to classpath - may be passed multiple times" - ) - - opt[String]('r', "repo") - .valueName("https://repository.apache.org/content/groups/public/") - .unbounded() - .optional() - .action((x, c) => c.copy(resolvers = c.resolvers :+ x)) - .text( - "additional repositories to resolve dependencies - may be passed multiple times" - ) - - opt[String]("command") - .action((x, c) => c.copy(command = Some(x))) - .text("select one of multiple @main methods") - - note("Plugin Management") - - opt[String]("add-plugin") - .action((x, c) => c.copy(addPlugin = Some(x))) - .text("Plugin zip file to add to the installation") - - opt[String]("remove-plugin") - .action((x, c) => c.copy(rmPlugin = Some(x))) - .text("Name of plugin to remove from the installation") - - opt[Unit]("plugins") - .action((_, c) => c.copy(listPlugins = true)) - .text("List available plugins and layer creators") - - opt[String]("run") - .action((x, c) => c.copy(pluginToRun = Some(x))) - .text("Run layer creator. Get a list via --plugins") - - opt[String]("src") - .action((x, c) => c.copy(src = Some(x))) - .text("Source code directory to run layer creator on") - - opt[String]("language") - .action((x, c) => c.copy(language = Some(x))) - .text("Language to use in CPG creation") - - opt[Unit]("overwrite") - .action((_, c) => c.copy(overwrite = true)) - .text("Overwrite CPG if it already exists") - - opt[Unit]("store") - .action((_, c) => c.copy(store = true)) - .text("Store graph changes made by layer creator") - - note("REST server mode") - - opt[Unit]("server") - .action((_, c) => c.copy(server = true)) - .text("run as HTTP server") - - opt[String]("server-host") - .action((x, c) => c.copy(serverHost = x)) - .text("Hostname on which to expose the Chen server") - - opt[Int]("server-port") - .action((x, c) => c.copy(serverPort = x)) - .text("Port on which to expose the Chen server") - - opt[String]("server-auth-username") - .action((x, c) => c.copy(serverAuthUsername = Option(x))) - .text("Basic auth username for the Chen server") - - opt[String]("server-auth-password") - .action((x, c) => c.copy(serverAuthPassword = Option(x))) - .text("Basic auth password for the Chen server") - - note("Misc") - - arg[java.io.File]("") - .optional() - .action((x, c) => c.copy(cpgToLoad = Some(x.toScala))) - .text("Atom to load") - - opt[String]("for-input-path") - .action((x, c) => c.copy(forInputPath = Some(x))) - .text("Open CPG for given input path - overrides ") - - opt[Unit]("nocolors") - .action((_, c) => c.copy(nocolors = true)) - .text("turn off colors") - - opt[Unit]("verbose") - .action((_, c) => c.copy(verbose = true)) - .text("enable verbose output (predef, resolved dependency jars, ...)") - - opt[Int]("maxHeight") - .action((x, c) => c.copy(maxHeight = Some(x))) - .text( - "Maximum number lines to print before output gets truncated (default: no limit)" - ) - - help("help") - .text("Print this help text") - - // note: if config is really `None` an error message would have been displayed earlier - parser.parse(args, Config()).get - end parseConfig - - /** Entry point for Chen's integrated REPL and plugin manager */ - protected def run(config: Config): Unit = - if config.listPlugins then - printPluginsAndLayerCreators(config) - else if config.addPlugin.isDefined then - new PluginManager(InstallConfig().rootPath).add(config.addPlugin.get) - else if config.rmPlugin.isDefined then - new PluginManager(InstallConfig().rootPath).rm(config.rmPlugin.get) - else if config.scriptFile.isDefined then - val scriptReturn = runScript(config) - if scriptReturn.isFailure then - println(scriptReturn.failed.get.getMessage) - System.exit(1) - else if config.server then - GlobalReporting.enable() - startHttpServer(config) - else if config.pluginToRun.isDefined then - runPlugin(config, jProduct.name) - else - startInteractiveShell(config) - - protected def createPredefFile(additionalLines: Seq[String] = Nil): Path = - val tmpFile = Files.createTempFile("chen-predef", "sc") - Files.write(tmpFile, (predefLines ++ additionalLines).asJava) - tmpFile.toAbsolutePath - - /** code that is executed on startup */ - protected def predefLines: Seq[String] - - protected def greeting: String - - protected def promptStr: String - - protected def onExitCode: String + def jProduct: JProduct + + protected def parseConfig(args: Array[String]): Config = + val parser = new scopt.OptionParser[Config](jProduct.name): + override def errorOnUnknownArgument = false + + note("Script execution") + + opt[Path]("script") + .action((x, c) => c.copy(scriptFile = Some(x))) + .text("path to script file: will execute and exit") + + opt[String]("param") + .valueName("param1=value1") + .unbounded() + .optional() + .action { (x, c) => + x.split("=", 2) match + case Array(key, value) => c.copy(params = c.params + (key -> value)) + case _ => + throw new IllegalArgumentException(s"unable to parse param input $x") + } + .text("key/value pair for main function in script - may be passed multiple times") + + opt[Path]("import") + .valueName("script1.sc") + .unbounded() + .optional() + .action((x, c) => c.copy(additionalImports = c.additionalImports :+ x)) + .text( + "import (and run) additional script(s) on startup - may be passed multiple times" + ) + + opt[String]('d', "dep") + .valueName("com.michaelpollmeier:versionsort:1.0.7") + .unbounded() + .optional() + .action((x, c) => c.copy(dependencies = c.dependencies :+ x)) + .text( + "add artifacts (including transitive dependencies) for given maven coordinate to classpath - may be passed multiple times" + ) + + opt[String]('r', "repo") + .valueName("https://repository.apache.org/content/groups/public/") + .unbounded() + .optional() + .action((x, c) => c.copy(resolvers = c.resolvers :+ x)) + .text( + "additional repositories to resolve dependencies - may be passed multiple times" + ) + + opt[String]("command") + .action((x, c) => c.copy(command = Some(x))) + .text("select one of multiple @main methods") + + note("Plugin Management") + + opt[String]("add-plugin") + .action((x, c) => c.copy(addPlugin = Some(x))) + .text("Plugin zip file to add to the installation") + + opt[String]("remove-plugin") + .action((x, c) => c.copy(rmPlugin = Some(x))) + .text("Name of plugin to remove from the installation") + + opt[Unit]("plugins") + .action((_, c) => c.copy(listPlugins = true)) + .text("List available plugins and layer creators") + + opt[String]("run") + .action((x, c) => c.copy(pluginToRun = Some(x))) + .text("Run layer creator. Get a list via --plugins") + + opt[String]("src") + .action((x, c) => c.copy(src = Some(x))) + .text("Source code directory to run layer creator on") + + opt[String]("language") + .action((x, c) => c.copy(language = Some(x))) + .text("Language to use in CPG creation") + + opt[Unit]("overwrite") + .action((_, c) => c.copy(overwrite = true)) + .text("Overwrite CPG if it already exists") + + opt[Unit]("store") + .action((_, c) => c.copy(store = true)) + .text("Store graph changes made by layer creator") + + note("REST server mode") + + opt[Unit]("server") + .action((_, c) => c.copy(server = true)) + .text("run as HTTP server") + + opt[String]("server-host") + .action((x, c) => c.copy(serverHost = x)) + .text("Hostname on which to expose the Chen server") + + opt[Int]("server-port") + .action((x, c) => c.copy(serverPort = x)) + .text("Port on which to expose the Chen server") + + opt[String]("server-auth-username") + .action((x, c) => c.copy(serverAuthUsername = Option(x))) + .text("Basic auth username for the Chen server") + + opt[String]("server-auth-password") + .action((x, c) => c.copy(serverAuthPassword = Option(x))) + .text("Basic auth password for the Chen server") + + note("Misc") + + arg[java.io.File]("") + .optional() + .action((x, c) => c.copy(cpgToLoad = Some(x.toScala))) + .text("Atom to load") + + opt[String]("for-input-path") + .action((x, c) => c.copy(forInputPath = Some(x))) + .text("Open CPG for given input path - overrides ") + + opt[Unit]("nocolors") + .action((_, c) => c.copy(nocolors = true)) + .text("turn off colors") + + opt[Unit]("verbose") + .action((_, c) => c.copy(verbose = true)) + .text("enable verbose output (predef, resolved dependency jars, ...)") + + opt[Int]("maxHeight") + .action((x, c) => c.copy(maxHeight = Some(x))) + .text( + "Maximum number lines to print before output gets truncated (default: no limit)" + ) + + help("help") + .text("Print this help text") + + // note: if config is really `None` an error message would have been displayed earlier + parser.parse(args, Config()).get + end parseConfig + + /** Entry point for Chen's integrated REPL and plugin manager */ + protected def run(config: Config): Unit = + if config.listPlugins then + printPluginsAndLayerCreators(config) + else if config.addPlugin.isDefined then + new PluginManager(InstallConfig().rootPath).add(config.addPlugin.get) + else if config.rmPlugin.isDefined then + new PluginManager(InstallConfig().rootPath).rm(config.rmPlugin.get) + else if config.scriptFile.isDefined then + val scriptReturn = runScript(config) + if scriptReturn.isFailure then + println(scriptReturn.failed.get.getMessage) + System.exit(1) + else if config.server then + GlobalReporting.enable() + startHttpServer(config) + else if config.pluginToRun.isDefined then + runPlugin(config, jProduct.name) + else + startInteractiveShell(config) + + protected def createPredefFile(additionalLines: Seq[String] = Nil): Path = + val tmpFile = Files.createTempFile("chen-predef", "sc") + Files.write(tmpFile, (predefLines ++ additionalLines).asJava) + tmpFile.toAbsolutePath + + /** code that is executed on startup */ + protected def predefLines: Seq[String] + + protected def greeting: String + + protected def promptStr: String + + protected def onExitCode: String end BridgeBase trait InteractiveShell: - this: BridgeBase => - protected def startInteractiveShell(config: Config) = - val replConfig = config.cpgToLoad.map { cpgFile => - "importCpg(\"" + cpgFile + "\")" - } ++ config.forInputPath.map { name => - s""" + this: BridgeBase => + protected def startInteractiveShell(config: Config) = + val replConfig = config.cpgToLoad.map { cpgFile => + "importCpg(\"" + cpgFile + "\")" + } ++ config.forInputPath.map { name => + s""" |openForInputPath(\"$name\") |""".stripMargin - } - - val predefFile = createPredefFile(replConfig.toSeq) - - replpp.InteractiveShell.run( - replpp.Config( - predefFiles = predefFile +: config.additionalImports, - nocolors = config.nocolors, - classpathConfig = replpp.Config - .ForClasspath( - inheritClasspath = true, - dependencies = config.dependencies, - resolvers = config.resolvers - ), - greeting = Option(greeting), - prompt = Option(promptStr), - onExitCode = Option(onExitCode), - maxHeight = config.maxHeight - ) - ) - end startInteractiveShell + } + + val predefFile = createPredefFile(replConfig.toSeq) + + replpp.InteractiveShell.run( + replpp.Config( + predefFiles = predefFile +: config.additionalImports, + nocolors = config.nocolors, + classpathConfig = replpp.Config + .ForClasspath( + inheritClasspath = true, + dependencies = config.dependencies, + resolvers = config.resolvers + ), + greeting = Option(greeting), + prompt = Option(promptStr), + onExitCode = Option(onExitCode), + maxHeight = config.maxHeight + ) + ) + end startInteractiveShell end InteractiveShell trait ScriptExecution: - this: BridgeBase => - - def runScript(config: Config): Try[Unit] = - val scriptFile = - config.scriptFile.getOrElse(throw new AssertionError("no script file configured")) - if !Files.exists(scriptFile) then - Try(throw new AssertionError(s"given script file `$scriptFile` does not exist")) - else - val predefFile = createPredefFile(importCpgCode(config)) - val scriptReturn = ScriptRunner.exec( - replpp.Config( - predefFiles = predefFile +: config.additionalImports, - scriptFile = Option(scriptFile), - command = config.command, - params = config.params, - verbose = config.verbose, - classpathConfig = replpp.Config - .ForClasspath( - inheritClasspath = true, - dependencies = config.dependencies, - resolvers = config.resolvers - ) + this: BridgeBase => + + def runScript(config: Config): Try[Unit] = + val scriptFile = + config.scriptFile.getOrElse(throw new AssertionError("no script file configured")) + if !Files.exists(scriptFile) then + Try(throw new AssertionError(s"given script file `$scriptFile` does not exist")) + else + val predefFile = createPredefFile(importCpgCode(config)) + val scriptReturn = ScriptRunner.exec( + replpp.Config( + predefFiles = predefFile +: config.additionalImports, + scriptFile = Option(scriptFile), + command = config.command, + params = config.params, + verbose = config.verbose, + classpathConfig = replpp.Config + .ForClasspath( + inheritClasspath = true, + dependencies = config.dependencies, + resolvers = config.resolvers ) - ) - if config.verbose && scriptReturn.isFailure then - println(scriptReturn.failed.get.getMessage) - scriptReturn - end if - end runScript - - /** For the given config, generate a list of commands to import the CPG - */ - private def importCpgCode(config: Config): List[String] = - config.cpgToLoad.map { cpgFile => - "importAtom(\"" + cpgFile + "\")" - }.toList ++ config.forInputPath.map { name => - s""" + ) + ) + if config.verbose && scriptReturn.isFailure then + println(scriptReturn.failed.get.getMessage) + scriptReturn + end if + end runScript + + /** For the given config, generate a list of commands to import the CPG + */ + private def importCpgCode(config: Config): List[String] = + config.cpgToLoad.map { cpgFile => + "importAtom(\"" + cpgFile + "\")" + }.toList ++ config.forInputPath.map { name => + s""" |openForInputPath(\"$name\") |""".stripMargin - } + } end ScriptExecution trait PluginHandling: - this: BridgeBase => - - /** Print a summary of the available plugins and layer creators to the terminal. - */ - protected def printPluginsAndLayerCreators(config: Config): Unit = - println("Installed plugins:") - println("==================") - new PluginManager(InstallConfig().rootPath).listPlugins().foreach(println) - println("Available layer creators") - println() - withTemporaryScript(codeToListPlugins(), jProduct.name) { file => - runScript(config.copy(scriptFile = Some(file.path))).get - } - - private def codeToListPlugins(): String = - """ + this: BridgeBase => + + /** Print a summary of the available plugins and layer creators to the terminal. + */ + protected def printPluginsAndLayerCreators(config: Config): Unit = + println("Installed plugins:") + println("==================") + new PluginManager(InstallConfig().rootPath).listPlugins().foreach(println) + println("Available layer creators") + println() + withTemporaryScript(codeToListPlugins(), jProduct.name) { file => + runScript(config.copy(scriptFile = Some(file.path))).get + } + + private def codeToListPlugins(): String = + """ |println(run) | |""".stripMargin - /** Run plugin by generating a temporary script based on the given config and execute the script - */ - protected def runPlugin(config: Config, productName: String): Unit = - if config.src.isEmpty then - println("You must supply a source directory with the --src flag") - return - val code = loadOrCreateCpg(config, productName) - withTemporaryScript(code, productName) { file => - runScript(config.copy(scriptFile = Some(file.path))).get - } - - /** Create a command that loads an existing CPG or creates it, based on the given `config`. - */ - private def loadOrCreateCpg(config: Config, productName: String): String = - - val bundleName = config.pluginToRun.get - val src = better.files.File(config.src.get).path.toAbsolutePath.toString - val language = languageFromConfig(config, src) - - val storeCode = if config.store then "save" - else "" - val runDataflow = "run.ossdataflow" - val argsString = argsStringFromConfig(config) - - s""" + /** Run plugin by generating a temporary script based on the given config and execute the script + */ + protected def runPlugin(config: Config, productName: String): Unit = + if config.src.isEmpty then + println("You must supply a source directory with the --src flag") + return + val code = loadOrCreateCpg(config, productName) + withTemporaryScript(code, productName) { file => + runScript(config.copy(scriptFile = Some(file.path))).get + } + + /** Create a command that loads an existing CPG or creates it, based on the given `config`. + */ + private def loadOrCreateCpg(config: Config, productName: String): String = + + val bundleName = config.pluginToRun.get + val src = better.files.File(config.src.get).path.toAbsolutePath.toString + val language = languageFromConfig(config, src) + + val storeCode = if config.store then "save" + else "" + val runDataflow = "run.ossdataflow" + val argsString = argsStringFromConfig(config) + + s""" | if (${config.overwrite} || !workspace.projectExists("$src")) { | workspace.projects | .filter(_.inputPath == "$src") @@ -358,64 +358,64 @@ trait PluginHandling: | run.$bundleName | $storeCode |""".stripMargin - end loadOrCreateCpg - - private def languageFromConfig(config: Config, src: String): String = - config.language.getOrElse( - io.appthreat.console.cpgcreation - .guessLanguage(src) - .map { - case Languages.C | Languages.NEWC => "c" - case Languages.JAVA => "jvm" - case Languages.JAVASRC => "java" - case lang => lang.toLowerCase - } - .getOrElse("c") - ) - - private def argsStringFromConfig(config: Config): String = - config.frontendArgs match - case Array() => "" - case args => - val quotedArgs = args.map { arg => - "\"" ++ arg ++ "\"" - } - val argsString = quotedArgs.mkString(", ") - s", args=List($argsString)" - - private def withTemporaryScript(code: String, prefix: String)(f: File => Unit): Unit = - File.usingTemporaryDirectory(prefix + "-bundle") { dir => - val file = dir / "script.sc" - file.write(code) - f(file) - } + end loadOrCreateCpg + + private def languageFromConfig(config: Config, src: String): String = + config.language.getOrElse( + io.appthreat.console.cpgcreation + .guessLanguage(src) + .map { + case Languages.C | Languages.NEWC => "c" + case Languages.JAVA => "jvm" + case Languages.JAVASRC => "java" + case lang => lang.toLowerCase + } + .getOrElse("c") + ) + + private def argsStringFromConfig(config: Config): String = + config.frontendArgs match + case Array() => "" + case args => + val quotedArgs = args.map { arg => + "\"" ++ arg ++ "\"" + } + val argsString = quotedArgs.mkString(", ") + s", args=List($argsString)" + + private def withTemporaryScript(code: String, prefix: String)(f: File => Unit): Unit = + File.usingTemporaryDirectory(prefix + "-bundle") { dir => + val file = dir / "script.sc" + file.write(code) + f(file) + } end PluginHandling trait ServerHandling: - this: BridgeBase => - - protected def startHttpServer(config: Config): Unit = - val predefFile = createPredefFile(Nil) - - val baseConfig = replpp.Config( - predefFiles = predefFile +: config.additionalImports, - verbose = true, // always print what's happening - helps debugging - classpathConfig = replpp.Config - .ForClasspath( - inheritClasspath = true, - dependencies = config.dependencies, - resolvers = config.resolvers - ) - ) - - replpp.server.ReplServer.startHttpServer( - replpp.server.Config( - baseConfig, - serverHost = config.serverHost, - serverPort = config.serverPort, - serverAuthUsername = config.serverAuthUsername, - serverAuthPassword = config.serverAuthPassword + this: BridgeBase => + + protected def startHttpServer(config: Config): Unit = + val predefFile = createPredefFile(Nil) + + val baseConfig = replpp.Config( + predefFiles = predefFile +: config.additionalImports, + verbose = true, // always print what's happening - helps debugging + classpathConfig = replpp.Config + .ForClasspath( + inheritClasspath = true, + dependencies = config.dependencies, + resolvers = config.resolvers ) - ) - end startHttpServer + ) + + replpp.server.ReplServer.startHttpServer( + replpp.server.Config( + baseConfig, + serverHost = config.serverHost, + serverPort = config.serverPort, + serverAuthUsername = config.serverAuthUsername, + serverAuthPassword = config.serverAuthPassword + ) + ) + end startHttpServer end ServerHandling diff --git a/console/src/main/scala/io/appthreat/console/Commit.scala b/console/src/main/scala/io/appthreat/console/Commit.scala index 0e8a892f..7f2f94c8 100644 --- a/console/src/main/scala/io/appthreat/console/Commit.scala +++ b/console/src/main/scala/io/appthreat/console/Commit.scala @@ -5,22 +5,22 @@ import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext, Layer import overflowdb.BatchedUpdate.DiffGraphBuilder object Commit: - val overlayName: String = "commit" - val description: String = "Apply current custom diffgraph" - def defaultOpts = new CommitOptions(new DiffGraphBuilder) + val overlayName: String = "commit" + val description: String = "Apply current custom diffgraph" + def defaultOpts = new CommitOptions(new DiffGraphBuilder) class CommitOptions(var diffGraphBuilder: DiffGraphBuilder) extends LayerCreatorOptions class Commit(opts: CommitOptions) extends LayerCreator: - override val overlayName: String = Commit.overlayName - override val description: String = Commit.description - override val storeOverlayName: Boolean = false + override val overlayName: String = Commit.overlayName + override val description: String = Commit.description + override val storeOverlayName: Boolean = false - override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = - val pass: CpgPass = new CpgPass(context.cpg): - override val name = "commit" - override def run(builder: DiffGraphBuilder): Unit = - builder.absorb(opts.diffGraphBuilder) - runPass(pass, context, storeUndoInfo) - opts.diffGraphBuilder = new DiffGraphBuilder + override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = + val pass: CpgPass = new CpgPass(context.cpg): + override val name = "commit" + override def run(builder: DiffGraphBuilder): Unit = + builder.absorb(opts.diffGraphBuilder) + runPass(pass, context, storeUndoInfo) + opts.diffGraphBuilder = new DiffGraphBuilder diff --git a/console/src/main/scala/io/appthreat/console/Console.scala b/console/src/main/scala/io/appthreat/console/Console.scala index 87971fe4..4fb3df57 100644 --- a/console/src/main/scala/io/appthreat/console/Console.scala +++ b/console/src/main/scala/io/appthreat/console/Console.scala @@ -27,61 +27,61 @@ class Console[T <: Project]( baseDir: File = File.currentWorkingDirectory ) extends Reporting: - import Console.* - - private val _config = new ConsoleConfig() - def config: ConsoleConfig = _config - def console: Console[T] = this - - protected var workspaceManager: WorkspaceManager[T] = scala.compiletime.uninitialized - switchWorkspace(baseDir.path.resolve("workspace").toString) - protected def workspacePathName: String = workspaceManager.getPath - - private val nameOfCpgInProject = "cpg.bin" - implicit val resolver: ICallResolver = NoResolve - implicit val finder: NodeExtensionFinder = DefaultNodeExtensionFinder - - implicit val pyGlobal: me.shadaj.scalapy.py.Dynamic.global.type = py.Dynamic.global - var richTableLib: me.shadaj.scalapy.py.Dynamic = py.module("logging") - var richTreeLib: me.shadaj.scalapy.py.Dynamic = py.module("logging") - var richProgressLib: me.shadaj.scalapy.py.Dynamic = py.module("logging") - var richConsole: me.shadaj.scalapy.py.Dynamic = py.module("logging") - var richAvailable = true - try - richTableLib = py.module("rich.table") - richTreeLib = py.module("rich.tree") - richProgressLib = py.module("rich.progress") - richConsole = py.module("chenpy.logger").console - catch - case _: Exception => richAvailable = false - - implicit object ConsoleImageViewer extends ImageViewer: - def view(imagePathStr: String): Try[String] = - // We need to copy the file as the original one is only temporary - // and gets removed immediately after running this viewer instance asynchronously via .run(). - val tmpFile = - File(imagePathStr).copyTo(File.newTemporaryFile(suffix = ".svg"), overwrite = true) - tmpFile.deleteOnExit(swallowIOExceptions = true) - Try { - val command = if scala.util.Properties.isWin then - Seq("cmd.exe", "/C", config.tools.imageViewer) - else Seq(config.tools.imageViewer) - Process(command :+ tmpFile.path.toAbsolutePath.toString).run() - } match - case Success(_) => - // We never handle the actual result anywhere. - // Hence, we just pass a success message. - Success(s"Running viewer for '$tmpFile' finished.") - case Failure(exc) => - System.err.println("Executing image viewer failed. Is it installed? ") - System.err.println(exc) - Failure(exc) - end view - end ConsoleImageViewer - - @Doc( - info = "Access to the workspace directory", - longInfo = """ + import Console.* + + private val _config = new ConsoleConfig() + def config: ConsoleConfig = _config + def console: Console[T] = this + + protected var workspaceManager: WorkspaceManager[T] = scala.compiletime.uninitialized + switchWorkspace(baseDir.path.resolve("workspace").toString) + protected def workspacePathName: String = workspaceManager.getPath + + private val nameOfCpgInProject = "cpg.bin" + implicit val resolver: ICallResolver = NoResolve + implicit val finder: NodeExtensionFinder = DefaultNodeExtensionFinder + + implicit val pyGlobal: me.shadaj.scalapy.py.Dynamic.global.type = py.Dynamic.global + var richTableLib: me.shadaj.scalapy.py.Dynamic = py.module("logging") + var richTreeLib: me.shadaj.scalapy.py.Dynamic = py.module("logging") + var richProgressLib: me.shadaj.scalapy.py.Dynamic = py.module("logging") + var richConsole: me.shadaj.scalapy.py.Dynamic = py.module("logging") + var richAvailable = true + try + richTableLib = py.module("rich.table") + richTreeLib = py.module("rich.tree") + richProgressLib = py.module("rich.progress") + richConsole = py.module("chenpy.logger").console + catch + case _: Exception => richAvailable = false + + implicit object ConsoleImageViewer extends ImageViewer: + def view(imagePathStr: String): Try[String] = + // We need to copy the file as the original one is only temporary + // and gets removed immediately after running this viewer instance asynchronously via .run(). + val tmpFile = + File(imagePathStr).copyTo(File.newTemporaryFile(suffix = ".svg"), overwrite = true) + tmpFile.deleteOnExit(swallowIOExceptions = true) + Try { + val command = if scala.util.Properties.isWin then + Seq("cmd.exe", "/C", config.tools.imageViewer) + else Seq(config.tools.imageViewer) + Process(command :+ tmpFile.path.toAbsolutePath.toString).run() + } match + case Success(_) => + // We never handle the actual result anywhere. + // Hence, we just pass a success message. + Success(s"Running viewer for '$tmpFile' finished.") + case Failure(exc) => + System.err.println("Executing image viewer failed. Is it installed? ") + System.err.println(exc) + Failure(exc) + end view + end ConsoleImageViewer + + @Doc( + info = "Access to the workspace directory", + longInfo = """ |All auditing projects are stored in a workspace directory, and `workspace` |provides programmatic access to this directory. Entering `workspace` provides |a list of all projects, indicating which code the project makes accessible, @@ -110,13 +110,13 @@ class Console[T <: Project]( |workspace directory | |""", - example = "workspace" - ) - def workspace: WorkspaceManager[T] = workspaceManager + example = "workspace" + ) + def workspace: WorkspaceManager[T] = workspaceManager - @Doc( - info = "Close current workspace and open a different one", - longInfo = """ | By default, the workspace in $INSTALL_DIR/workspace is used. + @Doc( + info = "Close current workspace and open a different one", + longInfo = """ | By default, the workspace in $INSTALL_DIR/workspace is used. | This method allows specifying a different workspace directory | via the `pathName` parameter. | Before changing the workspace, the current workspace will be @@ -124,22 +124,22 @@ class Console[T <: Project]( | If `pathName` points to a non-existing directory, then a new | workspace is first created. |""" - ) - def switchWorkspace(pathName: String): Unit = - if workspaceManager != null then - report("Saving current workspace before changing workspace") - workspaceManager.projects.foreach { p => - p.close - } - workspaceManager = new WorkspaceManager[T](pathName, loader) - - @Doc(info = "Currently active project", example = "project") - def project: T = - workspace.projectByCpg(cpg).getOrElse(throw new RuntimeException("No active project")) - - @Doc( - info = "CPG of the active project", - longInfo = """ + ) + def switchWorkspace(pathName: String): Unit = + if workspaceManager != null then + report("Saving current workspace before changing workspace") + workspaceManager.projects.foreach { p => + p.close + } + workspaceManager = new WorkspaceManager[T](pathName, loader) + + @Doc(info = "Currently active project", example = "project") + def project: T = + workspace.projectByCpg(cpg).getOrElse(throw new RuntimeException("No active project")) + + @Doc( + info = "CPG of the active project", + longInfo = """ |Upon importing code, a project is created that holds |an intermediate representation called `Code Property Graph`. This |graph is a composition of low-level program representations such @@ -153,31 +153,31 @@ class Console[T <: Project]( |`cpg.method.l` lists all methods, while `cpg.finding.l` lists all findings |of potentially vulnerable code. |""", - example = "cpg.method.l" - ) - implicit def cpg: Cpg = workspace.cpg - def atom: Cpg = workspace.cpg - - /** All cpgs loaded in the workspace - */ - def cpgs: Iterator[Cpg] = - if workspace.projects.lastOption.isEmpty then - Iterator() - else - val activeProjectName = project.name - (workspace.projects.filter(_.cpg.isDefined).iterator.flatMap { project => - open(project.name) - Some(project.cpg) - } ++ Iterator({ open(activeProjectName); None })).flatten - - // Provide `.l` on iterators, specifically so - // that `cpgs.flatMap($query).l` is possible - implicit class ItExtend[X](it: Iterator[X]): - def l: List[X] = it.toList - - @Doc( - info = "Open project by name", - longInfo = """ + example = "cpg.method.l" + ) + implicit def cpg: Cpg = workspace.cpg + def atom: Cpg = workspace.cpg + + /** All cpgs loaded in the workspace + */ + def cpgs: Iterator[Cpg] = + if workspace.projects.lastOption.isEmpty then + Iterator() + else + val activeProjectName = project.name + (workspace.projects.filter(_.cpg.isDefined).iterator.flatMap { project => + open(project.name) + Some(project.cpg) + } ++ Iterator({ open(activeProjectName); None })).flatten + + // Provide `.l` on iterators, specifically so + // that `cpgs.flatMap($query).l` is possible + implicit class ItExtend[X](it: Iterator[X]): + def l: List[X] = it.toList + + @Doc( + info = "Open project by name", + longInfo = """ |open([projectName]) | |Opens the project named `name` and make it the active project. @@ -189,18 +189,18 @@ class Console[T <: Project]( |can be queried via `cpg`. Returns an optional reference to the |project, which is empty on error. |""", - example = """open("projectName")""" - ) - def open(name: String): Option[Project] = - val projectName = fixProjectNameAndComplainOnFix(name) - workspace.openProject(projectName).map { project => - project - } - end open - - @Doc( - info = "Open project for input path", - longInfo = """ + example = """open("projectName")""" + ) + def open(name: String): Option[Project] = + val projectName = fixProjectNameAndComplainOnFix(name) + workspace.openProject(projectName).map { project => + project + } + end open + + @Doc( + info = "Open project for input path", + longInfo = """ |openForInputPath([input-path]) | |Opens the project of the CPG generated for the input path `input-path`. @@ -209,73 +209,73 @@ class Console[T <: Project]( |can be queried via `cpg`. Returns an optional reference to the |project, which is empty on error. |""" - ) - def openForInputPath(inputPath: String): Option[Project] = - val absInputPath = File(inputPath).path.toAbsolutePath.toString - workspace.projects - .filter(x => x.inputPath == absInputPath) - .map(_.name) - .map(open) - .headOption - .flatten - end openForInputPath - - /** Open the active project - */ - def open: Option[Project] = - workspace.projects.lastOption.flatMap { p => - open(p.name) - } - - /** Delete project from disk and remove it from the workspace manager. Returns the (now invalid) - * project. - * @param name - * the name of the project - */ - @Doc(info = "Close and remove project from disk", example = "delete(projectName)") - def delete(name: String): Option[Unit] = - workspaceManager.getActiveProject.foreach(_.cpg.foreach(_.close())) - defaultProjectNameIfEmpty(name).flatMap(workspace.deleteProject) - - @Doc(info = "Exit the REPL") - def exit: Unit = - workspace.projects.foreach(_.close) - System.exit(0) - - /** Delete the active project - */ - def delete: Option[Unit] = delete("") - - protected def defaultProjectNameIfEmpty(name: String): Option[String] = - if name.isEmpty then - val projectNameOpt = workspace.projectByCpg(cpg).map(_.name) - if projectNameOpt.isEmpty then - report("Fatal: cannot find project for active CPG") - projectNameOpt - else - Some(fixProjectNameAndComplainOnFix(name)) - - @Doc( - info = "Write all changes to disk", - longInfo = """ + ) + def openForInputPath(inputPath: String): Option[Project] = + val absInputPath = File(inputPath).path.toAbsolutePath.toString + workspace.projects + .filter(x => x.inputPath == absInputPath) + .map(_.name) + .map(open) + .headOption + .flatten + end openForInputPath + + /** Open the active project + */ + def open: Option[Project] = + workspace.projects.lastOption.flatMap { p => + open(p.name) + } + + /** Delete project from disk and remove it from the workspace manager. Returns the (now invalid) + * project. + * @param name + * the name of the project + */ + @Doc(info = "Close and remove project from disk", example = "delete(projectName)") + def delete(name: String): Option[Unit] = + workspaceManager.getActiveProject.foreach(_.cpg.foreach(_.close())) + defaultProjectNameIfEmpty(name).flatMap(workspace.deleteProject) + + @Doc(info = "Exit the REPL") + def exit: Unit = + workspace.projects.foreach(_.close) + System.exit(0) + + /** Delete the active project + */ + def delete: Option[Unit] = delete("") + + protected def defaultProjectNameIfEmpty(name: String): Option[String] = + if name.isEmpty then + val projectNameOpt = workspace.projectByCpg(cpg).map(_.name) + if projectNameOpt.isEmpty then + report("Fatal: cannot find project for active CPG") + projectNameOpt + else + Some(fixProjectNameAndComplainOnFix(name)) + + @Doc( + info = "Write all changes to disk", + longInfo = """ |Close and reopen all loaded CPGs. This ensures |that changes have been flushed to disk. | |Returns list of affected projects |""", - example = "save" - ) - def save: List[Project] = - report("Saving graphs on disk. This may take a while.") - workspace.projects.collect { - case p: Project if p.cpg.isDefined => - p.close - workspace.openProject(p.name) - }.flatten - - @Doc( - info = "Create new project from code", - longInfo = """ + example = "save" + ) + def save: List[Project] = + report("Saving graphs on disk. This may take a while.") + workspace.projects.collect { + case p: Project if p.cpg.isDefined => + p.close + workspace.openProject(p.name) + }.flatten + + @Doc( + info = "Create new project from code", + longInfo = """ |importCode(, [projectName], [namespaces], [language]) | |Type `importCode` alone to get a list of all supported languages @@ -312,13 +312,13 @@ class Console[T <: Project]( |the filename found and possibly by looking into the file/directory. | |""", - example = """importCode("git url or path")""" - ) - def importCode = new ImportCode(this) + example = """importCode("git url or path")""" + ) + def importCode = new ImportCode(this) - @Doc( - info = "Create new project from existing atom", - longInfo = """ + @Doc( + info = "Create new project from existing atom", + longInfo = """ |importAtom(, [projectName]) | |Import an existing atom. @@ -332,15 +332,15 @@ class Console[T <: Project]( |is omitted, the path is derived from `inputPath` | |""", - example = """importAtom("app.atom")""" - ) - def importAtom(inputPath: String, projectName: String = ""): Unit = - importCpg(inputPath, projectName, false) - summary - end importAtom - @Doc( - info = "Create new project from existing CPG", - longInfo = """ + example = """importAtom("app.atom")""" + ) + def importAtom(inputPath: String, projectName: String = ""): Unit = + importCpg(inputPath, projectName, false) + summary + end importAtom + @Doc( + info = "Create new project from existing CPG", + longInfo = """ |importCpg(, [projectName], [enhance]) | |Import an existing CPG. The CPG is stored as part @@ -359,160 +359,160 @@ class Console[T <: Project]( |enhance: run default overlays and post-processing passes. Defaults to `true`. |Pass `enhance=false` to disable the enhancements. |""", - example = """importCpg("app.atom")""" - ) - def importCpg( - inputPath: String, - projectName: String = "", - enhance: Boolean = true - ): Option[Cpg] = - val name = - Option(projectName).filter(_.nonEmpty).getOrElse(deriveNameFromInputPath( - inputPath, - workspace - )) - - var cpgFile = File(inputPath) - if !cpgFile.exists then - cpgFile = - if inputPath.endsWith(".atom") || inputPath.endsWith(".⚛") || inputPath.endsWith( - ".zip" - ) || inputPath.endsWith(".cpg") || inputPath.endsWith(".bin") - then - File(inputPath) - else File(inputPath) / "app.atom" - if !cpgFile.exists then - report(s"CPG at $inputPath does not exist. Bailing out.") + example = """importCpg("app.atom")""" + ) + def importCpg( + inputPath: String, + projectName: String = "", + enhance: Boolean = true + ): Option[Cpg] = + val name = + Option(projectName).filter(_.nonEmpty).getOrElse(deriveNameFromInputPath( + inputPath, + workspace + )) + + var cpgFile = File(inputPath) + if !cpgFile.exists then + cpgFile = + if inputPath.endsWith(".atom") || inputPath.endsWith(".⚛") || inputPath.endsWith( + ".zip" + ) || inputPath.endsWith(".cpg") || inputPath.endsWith(".bin") + then + File(inputPath) + else File(inputPath) / "app.atom" + if !cpgFile.exists then + report(s"CPG at $inputPath does not exist. Bailing out.") + return None + val pathToProject = workspace.createProject(inputPath, name) + val cpgDestinationPathOpt = pathToProject.map(_.resolve(nameOfCpgInProject)) + + if cpgDestinationPathOpt.isEmpty then + report(s"Error creating project for input path: `$inputPath`") + return None + + val cpgDestinationPath = cpgDestinationPathOpt.get + + if CpgLoader.isLegacyCpg(cpgFile) then + report("You have provided a legacy proto CPG. Attempting conversion.") + try + CpgConverter.convertProtoCpgToOverflowDb( + cpgFile.path.toString, + cpgDestinationPath.toString + ) + catch + case exc: Exception => + report("Error converting legacy CPG: " + exc.getMessage) return None - val pathToProject = workspace.createProject(inputPath, name) - val cpgDestinationPathOpt = pathToProject.map(_.resolve(nameOfCpgInProject)) + else + cpgFile.copyTo(cpgDestinationPath, overwrite = true) - if cpgDestinationPathOpt.isEmpty then - report(s"Error creating project for input path: `$inputPath`") - return None + val cpgOpt = open(name).flatMap(_.cpg) + + if cpgOpt.isEmpty then + workspace.deleteProject(name) + + cpgOpt + .filter(_.metaData.hasNext) + .foreach { cpg => + if enhance then applyDefaultOverlays(cpg) + applyPostProcessingPasses(cpg) + } + cpgOpt + end importCpg - val cpgDestinationPath = cpgDestinationPathOpt.get - - if CpgLoader.isLegacyCpg(cpgFile) then - report("You have provided a legacy proto CPG. Attempting conversion.") - try - CpgConverter.convertProtoCpgToOverflowDb( - cpgFile.path.toString, - cpgDestinationPath.toString - ) - catch - case exc: Exception => - report("Error converting legacy CPG: " + exc.getMessage) - return None - else - cpgFile.copyTo(cpgDestinationPath, overwrite = true) - - val cpgOpt = open(name).flatMap(_.cpg) - - if cpgOpt.isEmpty then - workspace.deleteProject(name) - - cpgOpt - .filter(_.metaData.hasNext) - .foreach { cpg => - if enhance then applyDefaultOverlays(cpg) - applyPostProcessingPasses(cpg) - } - cpgOpt - end importCpg - - @Doc( - info = "Close project by name", - longInfo = """|Close project. Resources are freed but the project remains on disk. + @Doc( + info = "Close project by name", + longInfo = """|Close project. Resources are freed but the project remains on disk. |The project remains active, that is, calling `cpg` now raises an |exception. A different project can now be activated using `open`. |""", - example = "close(projectName)" - ) - def close(name: String): Option[Project] = - defaultProjectNameIfEmpty(name).flatMap(workspace.closeProject) - - def close: Option[Project] = close("") - - /** Close the project and open it again. - * - * @param name - * the name of the project - */ - def reload(name: String): Option[Project] = - close(name).flatMap(p => open(p.name)) - - @Doc( - info = "Display summary information", - longInfo = - """|Displays summary about the loaded atom such as the number of files, methods, annotations etc. + example = "close(projectName)" + ) + def close(name: String): Option[Project] = + defaultProjectNameIfEmpty(name).flatMap(workspace.closeProject) + + def close: Option[Project] = close("") + + /** Close the project and open it again. + * + * @param name + * the name of the project + */ + def reload(name: String): Option[Project] = + close(name).flatMap(p => open(p.name)) + + @Doc( + info = "Display summary information", + longInfo = + """|Displays summary about the loaded atom such as the number of files, methods, annotations etc. |Requires the python modules to be installed. |""", - example = "summary" + example = "summary" + ) + def summary(as_text: Boolean): String = + val table = richTableLib.Table(title = "Atom Summary") + table.add_column("Node Type") + table.add_column("Count") + table.add_row("Files", "" + atom.file.size) + table.add_row("Methods", "" + atom.method.size) + table.add_row("Annotations", "" + atom.annotation.size) + table.add_row("Imports", "" + atom.imports.size) + table.add_row("Literals", "" + atom.literal.size) + table.add_row("Config Files", "" + atom.configFile.size) + table.add_row( + "Validation tags", + "[#5A7C90]" + atom.tag.name("(validation|sanitization).*").name.size + "[/#5A7C90]" ) - def summary(as_text: Boolean): String = - val table = richTableLib.Table(title = "Atom Summary") - table.add_column("Node Type") - table.add_column("Count") - table.add_row("Files", "" + atom.file.size) - table.add_row("Methods", "" + atom.method.size) - table.add_row("Annotations", "" + atom.annotation.size) - table.add_row("Imports", "" + atom.imports.size) - table.add_row("Literals", "" + atom.literal.size) - table.add_row("Config Files", "" + atom.configFile.size) - table.add_row( - "Validation tags", - "[#5A7C90]" + atom.tag.name("(validation|sanitization).*").name.size + "[/#5A7C90]" - ) - table.add_row( - "Unique packages", - "[#5A7C90]" + atom.tag.name("pkg.*").name.dedup.size + "[/#5A7C90]" - ) - table.add_row( - "Framework tags", - "[#5A7C90]" + atom.tag.name("framework.*").name.size + "[/#5A7C90]" - ) - table.add_row( - "Framework input", - "[#5A7C90]" + atom.tag.name("framework-(input|route)").name.size + "[/#5A7C90]" - ) - table.add_row( - "Framework output", - "[#5A7C90]" + atom.tag.name("framework-output").name.size + "[/#5A7C90]" - ) - table.add_row( - "Crypto tags", - "[#5A7C90]" + atom.tag.name("crypto.*").name.size + "[/#5A7C90]" - ) - val appliedOverlays = Overlays.appliedOverlays(atom) - if appliedOverlays.nonEmpty then table.add_row("Overlays", "" + appliedOverlays.size) - richConsole.clear() - richConsole.print(table) - if as_text then richConsole.export_text().as[String] else "" - end summary - def summary: String = summary(as_text = false) - - @Doc( - info = "List files", - longInfo = """|Lists the files from the loaded atom. + table.add_row( + "Unique packages", + "[#5A7C90]" + atom.tag.name("pkg.*").name.dedup.size + "[/#5A7C90]" + ) + table.add_row( + "Framework tags", + "[#5A7C90]" + atom.tag.name("framework.*").name.size + "[/#5A7C90]" + ) + table.add_row( + "Framework input", + "[#5A7C90]" + atom.tag.name("framework-(input|route)").name.size + "[/#5A7C90]" + ) + table.add_row( + "Framework output", + "[#5A7C90]" + atom.tag.name("framework-output").name.size + "[/#5A7C90]" + ) + table.add_row( + "Crypto tags", + "[#5A7C90]" + atom.tag.name("crypto.*").name.size + "[/#5A7C90]" + ) + val appliedOverlays = Overlays.appliedOverlays(atom) + if appliedOverlays.nonEmpty then table.add_row("Overlays", "" + appliedOverlays.size) + richConsole.clear() + richConsole.print(table) + if as_text then richConsole.export_text().as[String] else "" + end summary + def summary: String = summary(as_text = false) + + @Doc( + info = "List files", + longInfo = """|Lists the files from the loaded atom. |Requires the python modules to be installed. |""", - example = "files" - ) - def files(title: String = "Files", as_text: Boolean): String = - val table = richTableLib.Table(title = title, highlight = true) - table.add_column("File Name") - table.add_column("Method Count") - atom.file.whereNot(_.name("")).foreach { f => - table.add_row(f.name, "" + f.method.size) - } - richConsole.print(table) - if as_text then richConsole.export_text().as[String] else "" - def files: String = files("Files", as_text = false) - - @Doc( - info = "List methods", - longInfo = """|Lists the methods by files from the loaded atom. + example = "files" + ) + def files(title: String = "Files", as_text: Boolean): String = + val table = richTableLib.Table(title = title, highlight = true) + table.add_column("File Name") + table.add_column("Method Count") + atom.file.whereNot(_.name("")).foreach { f => + table.add_row(f.name, "" + f.method.size) + } + richConsole.print(table) + if as_text then richConsole.export_text().as[String] else "" + def files: String = files("Files", as_text = false) + + @Doc( + info = "List methods", + longInfo = """|Lists the methods by files from the loaded atom. |Requires the python modules to be installed. | |Parameters: @@ -520,366 +520,367 @@ class Console[T <: Project]( |title: Title for the table. Default Methods. |tree: Display as a tree instead of table |""", - example = "methods('Methods', includeCalls=true, tree=true)" - ) - def methods( - title: String = "Methods", - includeCalls: Boolean = false, - tree: Boolean = false, - as_text: Boolean = false - ): String = - if tree || includeCalls then - val rootTree = richTreeLib.Tree(title, highlight = true) - atom.file.whereNot(_.name("<(unknown|includes)>")).foreach { f => - val childTree = richTreeLib.Tree(f.name, highlight = true) - f.method.foreach(m => - val addedMethods = mutable.Map.empty[String, Boolean] - val mtree = childTree.add(m.fullName) - if includeCalls then - m.call - .filterNot(_.name.startsWith(" - if !addedMethods.contains( - c.methodFullName - ) && c.methodFullName != "" && !c.methodFullName.startsWith( - "{ " - ) - then - mtree - .add( - c.methodFullName + (if c.callee( - NoResolve - ).nonEmpty && c.callee( - NoResolve - ).head.nonEmpty && c.callee( - NoResolve - ).head.isExternal - then " :right_arrow_curving_up:" - else "") - ) - addedMethods += c.methodFullName -> true + example = "methods('Methods', includeCalls=true, tree=true)" + ) + def methods( + title: String = "Methods", + includeCalls: Boolean = false, + tree: Boolean = false, + as_text: Boolean = false + ): String = + if tree || includeCalls then + val rootTree = richTreeLib.Tree(title, highlight = true) + atom.file.whereNot(_.name("<(unknown|includes)>")).foreach { f => + val childTree = richTreeLib.Tree(f.name, highlight = true) + f.method.foreach(m => + val addedMethods = mutable.Map.empty[String, Boolean] + val mtree = childTree.add(m.fullName) + if includeCalls then + m.call + .filterNot(_.name.startsWith(" + if !addedMethods.contains( + c.methodFullName + ) && c.methodFullName != "" && !c + .methodFullName.startsWith( + "{ " ) - end if - ) - rootTree.add(childTree) - } - richConsole.print(rootTree) - if as_text then richConsole.export_text().as[String] else "" - else - val table = richTableLib.Table(title = title, highlight = true, show_lines = true) - table.add_column("File Name") - table.add_column("Methods") - atom.file.whereNot(_.name("<(unknown|includes)>")).foreach { f => - table.add_row( - f.name, - f.method.filterNot(m => - m.fullName.endsWith("") || m.fullName.endsWith( - "" - ) || m.fullName.endsWith( - "" - ) || m.name.isEmpty - ).map(m => - var methodDisplayStr = if m.tag.nonEmpty then - s"""${m.fullName}\n[info]Tags: ${m.tag.name.mkString(", ")}[/info]""" - else m.fullName - if m.tag.nonEmpty && (m.tag.name.contains( - "validation" - ) || m.tag.name.contains("sanitization") || m.tag.name.contains( - "authentication" - ) || m.tag.name.contains("authorization")) - then methodDisplayStr = s"[green]$methodDisplayStr[/green]" - methodDisplayStr - ).l.mkString("\n") - ) - } - richConsole.print(table) - if as_text then richConsole.export_text().as[String] else "" - def methods: String = methods("Methods", as_text = false) - - @Doc( - info = "List annotations", - longInfo = """|Lists the method annotations by files from the loaded atom. - |Requires the python modules to be installed. - |""", - example = "annotations" - ) - def annotations(title: String = "Annotations", as_text: Boolean): String = + then + mtree + .add( + c.methodFullName + (if c.callee( + NoResolve + ).nonEmpty && c.callee( + NoResolve + ).head.nonEmpty && c.callee( + NoResolve + ).head.isExternal + then " :right_arrow_curving_up:" + else "") + ) + addedMethods += c.methodFullName -> true + ) + end if + ) + rootTree.add(childTree) + } + richConsole.print(rootTree) + if as_text then richConsole.export_text().as[String] else "" + else val table = richTableLib.Table(title = title, highlight = true, show_lines = true) table.add_column("File Name") table.add_column("Methods") - table.add_column("Annotations") - atom.file.whereNot(_.name("")).method.filter(_.annotation.nonEmpty).foreach { m => - table.add_row(m.location.filename, m.fullName, m.annotation.fullName.l.mkString("\n")) + atom.file.whereNot(_.name("<(unknown|includes)>")).foreach { f => + table.add_row( + f.name, + f.method.filterNot(m => + m.fullName.endsWith("") || m.fullName.endsWith( + "" + ) || m.fullName.endsWith( + "" + ) || m.name.isEmpty + ).map(m => + var methodDisplayStr = if m.tag.nonEmpty then + s"""${m.fullName}\n[info]Tags: ${m.tag.name.mkString(", ")}[/info]""" + else m.fullName + if m.tag.nonEmpty && (m.tag.name.contains( + "validation" + ) || m.tag.name.contains("sanitization") || m.tag.name.contains( + "authentication" + ) || m.tag.name.contains("authorization")) + then methodDisplayStr = s"[green]$methodDisplayStr[/green]" + methodDisplayStr + ).l.mkString("\n") + ) } richConsole.print(table) if as_text then richConsole.export_text().as[String] else "" + def methods: String = methods("Methods", as_text = false) - def annotations: String = annotations("Annotations", as_text = false) - - @Doc( - info = "List imports", - longInfo = """|Lists the imports by files from the loaded atom. + @Doc( + info = "List annotations", + longInfo = """|Lists the method annotations by files from the loaded atom. |Requires the python modules to be installed. |""", - example = "imports" - ) - def imports(title: String = "Imports", as_text: Boolean): String = - val table = richTableLib.Table(title = title, highlight = true, show_lines = true) - table.add_column("File Name") - table.add_column("Import") - atom.imports.foreach { i => - table.add_row(i.file.name.l.mkString("\n"), i.importedEntity.getOrElse("")) - } - richConsole.print(table) - if as_text then richConsole.export_text().as[String] else "" - - def imports: String = imports("Imports", as_text = false) - - @Doc( - info = "List declarations", - longInfo = """|Lists the declarations by files from the loaded atom. + example = "annotations" + ) + def annotations(title: String = "Annotations", as_text: Boolean): String = + val table = richTableLib.Table(title = title, highlight = true, show_lines = true) + table.add_column("File Name") + table.add_column("Methods") + table.add_column("Annotations") + atom.file.whereNot(_.name("")).method.filter(_.annotation.nonEmpty).foreach { m => + table.add_row(m.location.filename, m.fullName, m.annotation.fullName.l.mkString("\n")) + } + richConsole.print(table) + if as_text then richConsole.export_text().as[String] else "" + + def annotations: String = annotations("Annotations", as_text = false) + + @Doc( + info = "List imports", + longInfo = """|Lists the imports by files from the loaded atom. |Requires the python modules to be installed. |""", - example = "declarations" - ) - def declarations(title: String = "Declarations", as_text: Boolean): String = - val table = richTableLib.Table(title = title, highlight = true, show_lines = true) - table.add_column("File Name") - table.add_column("Declarations") - atom.file.whereNot(_.name("")).foreach { f => - val dec: Set[Declaration] = - (f.assignment.argument(1).filterNot( - _.code == "this" - ).isIdentifier.nameNot("tmp[0-9]+$").refsTo ++ f.method.parameter - .filterNot(_.code == "this") - .filter(_.typeFullName != "ANY")).toSet - table.add_row(f.name, dec.name.toSet.mkString("\n")) - } - richConsole.print(table) - if as_text then richConsole.export_text().as[String] else "" - end declarations - - def declarations: String = declarations("Declarations", as_text = false) - - @Doc( - info = "List sensitive literals", - longInfo = """|Lists the sensitive literals by files from the loaded atom. + example = "imports" + ) + def imports(title: String = "Imports", as_text: Boolean): String = + val table = richTableLib.Table(title = title, highlight = true, show_lines = true) + table.add_column("File Name") + table.add_column("Import") + atom.imports.foreach { i => + table.add_row(i.file.name.l.mkString("\n"), i.importedEntity.getOrElse("")) + } + richConsole.print(table) + if as_text then richConsole.export_text().as[String] else "" + + def imports: String = imports("Imports", as_text = false) + + @Doc( + info = "List declarations", + longInfo = """|Lists the declarations by files from the loaded atom. |Requires the python modules to be installed. |""", - example = "sensitive" + example = "declarations" + ) + def declarations(title: String = "Declarations", as_text: Boolean): String = + val table = richTableLib.Table(title = title, highlight = true, show_lines = true) + table.add_column("File Name") + table.add_column("Declarations") + atom.file.whereNot(_.name("")).foreach { f => + val dec: Set[Declaration] = + (f.assignment.argument(1).filterNot( + _.code == "this" + ).isIdentifier.nameNot("tmp[0-9]+$").refsTo ++ f.method.parameter + .filterNot(_.code == "this") + .filter(_.typeFullName != "ANY")).toSet + table.add_row(f.name, dec.name.toSet.mkString("\n")) + } + richConsole.print(table) + if as_text then richConsole.export_text().as[String] else "" + end declarations + + def declarations: String = declarations("Declarations", as_text = false) + + @Doc( + info = "List sensitive literals", + longInfo = """|Lists the sensitive literals by files from the loaded atom. + |Requires the python modules to be installed. + |""", + example = "sensitive" + ) + def sensitive( + title: String = "Sensitive Literals", + pattern: String = "(secret|password|token|key|admin|root)", + as_text: Boolean = false + ): String = + val table = richTableLib.Table(title = title, highlight = true, show_lines = true) + table.add_column("File Name") + table.add_column("Sensitive Literals") + atom.file.whereNot(_.name("")).foreach { f => + val slits: Set[Literal] = + f.assignment.where(_.argument.order(1).code(s"(?i).*${pattern}.*")).argument.order( + 2 + ).isLiteral.toSet + table.add_row(f.name, if slits.nonEmpty then slits.code.mkString("\n") else "N/A") + } + richConsole.print(table) + if as_text then richConsole.export_text().as[String] else "" + end sensitive + + def sensitive: String = sensitive("Sensitive Literals", as_text = false) + + @Doc( + info = "Show graph edit distance from the source method to the comparison methods", + longInfo = """|Compute graph edit distance from the source method to the comparison methods. + |""", + example = "distance(source method iterator, comparison method iterators)" + ) + def distance(sourceTrav: Iterator[Method], sourceTravs: Iterator[Method]*): Seq[Double] = + val first_method = new MethodTraversal(sourceTrav.iterator).gml + sourceTravs.map { compareTrav => + val second_method = new MethodTraversal(compareTrav.iterator).gml + Torch.edit_distance(first_method, second_method) + } + + case class MethodDistance(filename: String, fullName: String, editDistance: Double) + + @Doc( + info = "Show methods similar to the given method", + longInfo = """|List methods to similar to the one based on graph edit distance. + |""", + example = "showSimilar(method full name)" + ) + def showSimilar( + methodFullName: String, + comparePattern: String = "", + upper_bound: Int = 500, + timeout: Int = 5, + as_text: Boolean = false + ): String = + val table = + richTableLib.Table( + title = s"Similarity analysis for `${methodFullName}`", + highlight = true, + show_lines = true + ) + val progress = richProgressLib.Progress(transient = true) + val first_method = atom.method.fullNameExact(methodFullName).gml + table.add_column("File Name") + table.add_column("Method Name") + table.add_column("Edit Distance") + val methodDistances = mutable.ArrayBuffer[MethodDistance]() + val base = if comparePattern.nonEmpty then atom.method.fullName(s".*${comparePattern}.*") + else atom.method.internal + base.whereNot(_.fullNameExact(methodFullName)).foreach { method => + py.`with`(progress) { mprogress => + val task = + mprogress.add_task(s"Analyzing ${method.fullName}", start = false, total = 100) + val edit_distance = + Torch.edit_distance( + first_method, + method.iterator.gml, + upper_bound = upper_bound, + timeout = timeout + ) + if edit_distance != -1 then + methodDistances += MethodDistance( + method.location.filename, + method.fullName, + edit_distance + ) + mprogress.stop_task(task) + mprogress.update(task, completed = 100) + } + } + methodDistances.sortInPlaceBy[Double](x => x.editDistance) + methodDistances.foreach(row => + table.add_row(row.filename, row.fullName, "" + row.editDistance) ) - def sensitive( - title: String = "Sensitive Literals", - pattern: String = "(secret|password|token|key|admin|root)", - as_text: Boolean = false - ): String = - val table = richTableLib.Table(title = title, highlight = true, show_lines = true) - table.add_column("File Name") - table.add_column("Sensitive Literals") - atom.file.whereNot(_.name("")).foreach { f => - val slits: Set[Literal] = - f.assignment.where(_.argument.order(1).code(s"(?i).*${pattern}.*")).argument.order( - 2 - ).isLiteral.toSet - table.add_row(f.name, if slits.nonEmpty then slits.code.mkString("\n") else "N/A") + richConsole.print(table) + if as_text then richConsole.export_text().as[String] else "" + end showSimilar + + def printDashes(count: Int) = + var tabStr = "+--- " + var i = 0 + while i < count do + tabStr = "| " + tabStr + i += 1 + tabStr + + @Doc( + info = "Show call tree for the given method", + longInfo = """|Show the call tree for the given method. + |""", + example = "callTree(method full name)" + ) + def callTree( + callerFullName: String, + tree: ListBuffer[String] = new ListBuffer[String](), + depth: Int = 3 + )(implicit atom: Cpg): ListBuffer[String] = + var dashCount = 0 + var lastCallerMethod = callerFullName + var lastDashCount = 0 + tree += callerFullName + + def findCallee(methodName: String, tree: ListBuffer[String]): ListBuffer[String] = + val calleeList = + atom.method.fullNameExact(methodName).callee.whereNot(_.name(".* + tree += s"${printDashes(dashCount)}${c.fullName}~~${c.location.filename}#${c.lineNumber.getOrElse(0)}" + findCallee(c.fullName, tree) } - richConsole.print(table) - if as_text then richConsole.export_text().as[String] else "" - end sensitive + tree - def sensitive: String = sensitive("Sensitive Literals", as_text = false) + findCallee(lastCallerMethod, tree) + end callTree - @Doc( - info = "Show graph edit distance from the source method to the comparison methods", - longInfo = """|Compute graph edit distance from the source method to the comparison methods. - |""", - example = "distance(source method iterator, comparison method iterators)" - ) - def distance(sourceTrav: Iterator[Method], sourceTravs: Iterator[Method]*): Seq[Double] = - val first_method = new MethodTraversal(sourceTrav.iterator).gml - sourceTravs.map { compareTrav => - val second_method = new MethodTraversal(compareTrav.iterator).gml - Torch.edit_distance(first_method, second_method) - } + def applyPostProcessingPasses(cpg: Cpg): Cpg = + new CpgGeneratorFactory(_config).forLanguage(cpg.metaData.language.l.head) match + case Some(frontend) => frontend.applyPostProcessingPasses(cpg) + case None => cpg - case class MethodDistance(filename: String, fullName: String, editDistance: Double) + def applyDefaultOverlays(cpg: Cpg): Cpg = + val appliedOverlays = Overlays.appliedOverlays(cpg) + if appliedOverlays.isEmpty then + report("Adding default overlays to base CPG") + _runAnalyzer(defaultOverlayCreators()*) + cpg - @Doc( - info = "Show methods similar to the given method", - longInfo = """|List methods to similar to the one based on graph edit distance. - |""", - example = "showSimilar(method full name)" - ) - def showSimilar( - methodFullName: String, - comparePattern: String = "", - upper_bound: Int = 500, - timeout: Int = 5, - as_text: Boolean = false - ): String = - val table = - richTableLib.Table( - title = s"Similarity analysis for `${methodFullName}`", - highlight = true, - show_lines = true - ) - val progress = richProgressLib.Progress(transient = true) - val first_method = atom.method.fullNameExact(methodFullName).gml - table.add_column("File Name") - table.add_column("Method Name") - table.add_column("Edit Distance") - val methodDistances = mutable.ArrayBuffer[MethodDistance]() - val base = if comparePattern.nonEmpty then atom.method.fullName(s".*${comparePattern}.*") - else atom.method.internal - base.whereNot(_.fullNameExact(methodFullName)).foreach { method => - py.`with`(progress) { mprogress => - val task = - mprogress.add_task(s"Analyzing ${method.fullName}", start = false, total = 100) - val edit_distance = - Torch.edit_distance( - first_method, - method.iterator.gml, - upper_bound = upper_bound, - timeout = timeout - ) - if edit_distance != -1 then - methodDistances += MethodDistance( - method.location.filename, - method.fullName, - edit_distance - ) - mprogress.stop_task(task) - mprogress.update(task, completed = 100) - } - } - methodDistances.sortInPlaceBy[Double](x => x.editDistance) - methodDistances.foreach(row => - table.add_row(row.filename, row.fullName, "" + row.editDistance) + def _runAnalyzer(overlayCreators: LayerCreator*): Cpg = + + overlayCreators.foreach { creator => + val overlayDirName = + workspace.getNextOverlayDirName(cpg, creator.overlayName) + + val projectOpt = workspace.projectByCpg(cpg) + if projectOpt.isEmpty then + throw new RuntimeException( + "No record for atom. Please use `importCode`/`importAtom/open`" ) - richConsole.print(table) - if as_text then richConsole.export_text().as[String] else "" - end showSimilar - - def printDashes(count: Int) = - var tabStr = "+--- " - var i = 0 - while i < count do - tabStr = "| " + tabStr - i += 1 - tabStr - - @Doc( - info = "Show call tree for the given method", - longInfo = """|Show the call tree for the given method. - |""", - example = "callTree(method full name)" + + if projectOpt.get.appliedOverlays.contains(creator.overlayName) then + report(s"Overlay ${creator.overlayName} already exists - skipping") + else + File(overlayDirName).createDirectories() + runCreator(creator, Some(overlayDirName)) + } + report( + "The graph has been modified. You may want to use the `save` command to persist changes to disk. All changes will also be saved collectively on exit" ) - def callTree( - callerFullName: String, - tree: ListBuffer[String] = new ListBuffer[String](), - depth: Int = 3 - )(implicit atom: Cpg): ListBuffer[String] = - var dashCount = 0 - var lastCallerMethod = callerFullName - var lastDashCount = 0 - tree += callerFullName - - def findCallee(methodName: String, tree: ListBuffer[String]): ListBuffer[String] = - val calleeList = - atom.method.fullNameExact(methodName).callee.whereNot(_.name(".* - tree += s"${printDashes(dashCount)}${c.fullName}~~${c.location.filename}#${c.lineNumber.getOrElse(0)}" - findCallee(c.fullName, tree) - } - tree - - findCallee(lastCallerMethod, tree) - end callTree - - def applyPostProcessingPasses(cpg: Cpg): Cpg = - new CpgGeneratorFactory(_config).forLanguage(cpg.metaData.language.l.head) match - case Some(frontend) => frontend.applyPostProcessingPasses(cpg) - case None => cpg - - def applyDefaultOverlays(cpg: Cpg): Cpg = - val appliedOverlays = Overlays.appliedOverlays(cpg) - if appliedOverlays.isEmpty then - report("Adding default overlays to base CPG") - _runAnalyzer(defaultOverlayCreators()*) - cpg - - def _runAnalyzer(overlayCreators: LayerCreator*): Cpg = - - overlayCreators.foreach { creator => - val overlayDirName = - workspace.getNextOverlayDirName(cpg, creator.overlayName) - - val projectOpt = workspace.projectByCpg(cpg) - if projectOpt.isEmpty then - throw new RuntimeException( - "No record for atom. Please use `importCode`/`importAtom/open`" - ) - - if projectOpt.get.appliedOverlays.contains(creator.overlayName) then - report(s"Overlay ${creator.overlayName} already exists - skipping") - else - File(overlayDirName).createDirectories() - runCreator(creator, Some(overlayDirName)) - } - report( - "The graph has been modified. You may want to use the `save` command to persist changes to disk. All changes will also be saved collectively on exit" - ) - cpg - end _runAnalyzer - - protected def runCreator(creator: LayerCreator, overlayDirName: Option[String]): Unit = - val context = new LayerCreatorContext(cpg, overlayDirName) - creator.run(context, storeUndoInfo = true) - - // We still tie the project name to the input path here - // if no project name has been provided. - - def fixProjectNameAndComplainOnFix(name: String): String = - val projectName = Some(name) - .filter(_.contains(java.io.File.separator)) - .map(x => deriveNameFromInputPath(x, workspace)) - .getOrElse(name) - if name != projectName then - System.err.println( - "Passing paths to `loadCpg` is deprecated, please use a project name" - ) - projectName + cpg + end _runAnalyzer + + protected def runCreator(creator: LayerCreator, overlayDirName: Option[String]): Unit = + val context = new LayerCreatorContext(cpg, overlayDirName) + creator.run(context, storeUndoInfo = true) + + // We still tie the project name to the input path here + // if no project name has been provided. + + def fixProjectNameAndComplainOnFix(name: String): String = + val projectName = Some(name) + .filter(_.contains(java.io.File.separator)) + .map(x => deriveNameFromInputPath(x, workspace)) + .getOrElse(name) + if name != projectName then + System.err.println( + "Passing paths to `loadCpg` is deprecated, please use a project name" + ) + projectName end Console object Console: - val nameOfLegacyCpgInProject = "app.atom" - - def deriveNameFromInputPath[T <: Project]( - inputPath: String, - workspace: WorkspaceManager[T] - ): String = - val name = File(inputPath).name - val project = workspace.project(name) - if project.isDefined && project.exists(_.inputPath != inputPath) then - var i = 1 - while workspace.project(name + i).isDefined do - i += 1 - name + i - else - name + val nameOfLegacyCpgInProject = "app.atom" + + def deriveNameFromInputPath[T <: Project]( + inputPath: String, + workspace: WorkspaceManager[T] + ): String = + val name = File(inputPath).name + val project = workspace.project(name) + if project.isDefined && project.exists(_.inputPath != inputPath) then + var i = 1 + while workspace.project(name + i).isDefined do + i += 1 + name + i + else + name class ConsoleException(message: String, cause: Option[Throwable]) extends RuntimeException(message, cause.orNull) with NoStackTrace: - def this(message: String) = this(message, None) - def this(message: String, cause: Throwable) = this(message, Option(cause)) + def this(message: String) = this(message, None) + def this(message: String, cause: Throwable) = this(message, Option(cause)) diff --git a/console/src/main/scala/io/appthreat/console/ConsoleConfig.scala b/console/src/main/scala/io/appthreat/console/ConsoleConfig.scala index b821ef3d..206ec61d 100644 --- a/console/src/main/scala/io/appthreat/console/ConsoleConfig.scala +++ b/console/src/main/scala/io/appthreat/console/ConsoleConfig.scala @@ -12,50 +12,50 @@ import scala.collection.mutable */ class InstallConfig(environment: Map[String, String] = sys.env): - /** determining the root path of the installation is rather complex unfortunately, because we - * support a variety of use cases: - * - running the installed distribution from the install dir - * - running the installed distribution anywhere else on the system - * - running a locally staged build (via `sbt stage` and then either `./chennai` or `cd - * platform/target/universal/stage; ./chennai`) - * - running a unit/integration test (note: the jars would be in the local cache, e.g. in - * ~/.coursier/cache) - */ - lazy val rootPath: File = - if environment.contains("CHEN_INSTALL_DIR") then - environment("CHEN_INSTALL_DIR").toFile - else - val uriToLibDir = - classOf[InstallConfig].getProtectionDomain.getCodeSource.getLocation.toURI - val pathToLibDir = File(uriToLibDir).parent - findRootDirectory(pathToLibDir).getOrElse { - val cwd = File.currentWorkingDirectory - findRootDirectory(cwd).getOrElse( - throw new AssertionError(s"""unable to find root installation directory + /** determining the root path of the installation is rather complex unfortunately, because we + * support a variety of use cases: + * - running the installed distribution from the install dir + * - running the installed distribution anywhere else on the system + * - running a locally staged build (via `sbt stage` and then either `./chennai` or `cd + * platform/target/universal/stage; ./chennai`) + * - running a unit/integration test (note: the jars would be in the local cache, e.g. in + * ~/.coursier/cache) + */ + lazy val rootPath: File = + if environment.contains("CHEN_INSTALL_DIR") then + environment("CHEN_INSTALL_DIR").toFile + else + val uriToLibDir = + classOf[InstallConfig].getProtectionDomain.getCodeSource.getLocation.toURI + val pathToLibDir = File(uriToLibDir).parent + findRootDirectory(pathToLibDir).getOrElse { + val cwd = File.currentWorkingDirectory + findRootDirectory(cwd).getOrElse( + throw new AssertionError(s"""unable to find root installation directory | context: tried to find marker file `$rootDirectoryMarkerFilename` | started search in both $pathToLibDir and $cwd and searched | $maxSearchDepth directories upwards""".stripMargin) - ) - } + ) + } - private val rootDirectoryMarkerFilename = ".installation_root" - private val maxSearchDepth = 10 + private val rootDirectoryMarkerFilename = ".installation_root" + private val maxSearchDepth = 10 - @tailrec - private def findRootDirectory( - currentSearchDir: File, - currentSearchDepth: Int = 0 - ): Option[File] = - if currentSearchDir.list.map(_.name).contains(rootDirectoryMarkerFilename) then - Some(currentSearchDir) - else if currentSearchDepth < maxSearchDepth && currentSearchDir.parentOption.isDefined then - findRootDirectory(currentSearchDir.parent) - else - None + @tailrec + private def findRootDirectory( + currentSearchDir: File, + currentSearchDepth: Int = 0 + ): Option[File] = + if currentSearchDir.list.map(_.name).contains(rootDirectoryMarkerFilename) then + Some(currentSearchDir) + else if currentSearchDepth < maxSearchDepth && currentSearchDir.parentOption.isDefined then + findRootDirectory(currentSearchDir.parent) + else + None end InstallConfig object InstallConfig: - def apply(): InstallConfig = new InstallConfig() + def apply(): InstallConfig = new InstallConfig() class ConsoleConfig( val install: InstallConfig = InstallConfig(), @@ -65,18 +65,18 @@ class ConsoleConfig( object ToolsConfig: - private val osSpecificOpenCmd: String = - if scala.util.Properties.isWin then "start" - else if scala.util.Properties.isMac then "open" - else "xdg-open" + private val osSpecificOpenCmd: String = + if scala.util.Properties.isWin then "start" + else if scala.util.Properties.isMac then "open" + else "xdg-open" - def apply(): ToolsConfig = new ToolsConfig() + def apply(): ToolsConfig = new ToolsConfig() class ToolsConfig(var imageViewer: String = ToolsConfig.osSpecificOpenCmd) class FrontendConfig(var cmdLineParams: Iterable[String] = mutable.Buffer()): - def withArgs(args: Iterable[String]): FrontendConfig = - new FrontendConfig(cmdLineParams ++ args) + def withArgs(args: Iterable[String]): FrontendConfig = + new FrontendConfig(cmdLineParams ++ args) object FrontendConfig: - def apply(): FrontendConfig = new FrontendConfig() + def apply(): FrontendConfig = new FrontendConfig() diff --git a/console/src/main/scala/io/appthreat/console/CpgConverter.scala b/console/src/main/scala/io/appthreat/console/CpgConverter.scala index d6ed755c..01d20653 100644 --- a/console/src/main/scala/io/appthreat/console/CpgConverter.scala +++ b/console/src/main/scala/io/appthreat/console/CpgConverter.scala @@ -5,8 +5,8 @@ import overflowdb.Config object CpgConverter: - def convertProtoCpgToOverflowDb(srcFilename: String, dstFilename: String): Unit = - val odbConfig = Config.withDefaults.withStorageLocation(dstFilename) - val config = - CpgLoaderConfig.withDefaults.doNotCreateIndexesOnLoad.withOverflowConfig(odbConfig) - CpgLoader.load(srcFilename, config).close + def convertProtoCpgToOverflowDb(srcFilename: String, dstFilename: String): Unit = + val odbConfig = Config.withDefaults.withStorageLocation(dstFilename) + val config = + CpgLoaderConfig.withDefaults.doNotCreateIndexesOnLoad.withOverflowConfig(odbConfig) + CpgLoader.load(srcFilename, config).close diff --git a/console/src/main/scala/io/appthreat/console/Help.scala b/console/src/main/scala/io/appthreat/console/Help.scala index de8290d3..e165ec9e 100644 --- a/console/src/main/scala/io/appthreat/console/Help.scala +++ b/console/src/main/scala/io/appthreat/console/Help.scala @@ -6,18 +6,18 @@ import overflowdb.traversal.help.{Table, DocFinder} object Help: - private val width = 80 + private val width = 80 - def overview(clazz: Class[?]): String = - val columnNames = List("command", "description", "example") - val rows = DocFinder - .findDocumentedMethodsOf(clazz) - .map { case StepDoc(_, funcName, doc) => - List(funcName, doc.info, doc.example) - } - .toList ++ List(runRow) + def overview(clazz: Class[?]): String = + val columnNames = List("command", "description", "example") + val rows = DocFinder + .findDocumentedMethodsOf(clazz) + .map { case StepDoc(_, funcName, doc) => + List(funcName, doc.info, doc.example) + } + .toList ++ List(runRow) - val header = formatNoQuotes(""" + val header = formatNoQuotes(""" | |Welcome to the interactive help system. Below you find |a table of all available top-level commands. To get @@ -29,40 +29,40 @@ object Help: | | |""".stripMargin) - header + "\n" + Table(columnNames, rows.sortBy(_.head)).render - end overview + header + "\n" + Table(columnNames, rows.sortBy(_.head)).render + end overview - def format(text: String): String = - "\"\"\"" + "\n" + formatNoQuotes(text) + "\"\"\"" + def format(text: String): String = + "\"\"\"" + "\n" + formatNoQuotes(text) + "\"\"\"" - def formatNoQuotes(text: String): String = - text.stripMargin - .split("\n\n") - .map(x => WordUtils.wrap(x.replace("\n", " "), width)) - .mkString("\n\n") - .trim + def formatNoQuotes(text: String): String = + text.stripMargin + .split("\n\n") + .map(x => WordUtils.wrap(x.replace("\n", " "), width)) + .mkString("\n\n") + .trim - private def runRow: List[String] = - List("run", "Run analyzer on active CPG", "run.securityprofile") + private def runRow: List[String] = + List("run", "Run analyzer on active CPG", "run.securityprofile") - // Since `run` is generated dynamically, it's not picked up when looking - // through methods via reflection, and therefore, we are adding - // it manually. - def runLongHelp: String = - Help.format(""" + // Since `run` is generated dynamically, it's not picked up when looking + // through methods via reflection, and therefore, we are adding + // it manually. + def runLongHelp: String = + Help.format(""" | |""".stripMargin) - def codeForHelpCommand(clazz: Class[?]): String = - val membersCode = DocFinder - .findDocumentedMethodsOf(clazz) - .map { case StepDoc(_, funcName, doc) => - s" val $funcName: String = ${Help.format(doc.longInfo)}" - } - .mkString("\n") + def codeForHelpCommand(clazz: Class[?]): String = + val membersCode = DocFinder + .findDocumentedMethodsOf(clazz) + .map { case StepDoc(_, funcName, doc) => + s" val $funcName: String = ${Help.format(doc.longInfo)}" + } + .mkString("\n") - val overview = Help.overview(clazz) - s""" + val overview = Help.overview(clazz) + s""" | class Helper() { | def run: String = Help.runLongHelp | override def toString: String = \"\"\"$overview\"\"\" @@ -72,5 +72,5 @@ object Help: | | val help = new Helper |""".stripMargin - end codeForHelpCommand + end codeForHelpCommand end Help diff --git a/console/src/main/scala/io/appthreat/console/JProduct.scala b/console/src/main/scala/io/appthreat/console/JProduct.scala index ab5765d2..188c2ef9 100644 --- a/console/src/main/scala/io/appthreat/console/JProduct.scala +++ b/console/src/main/scala/io/appthreat/console/JProduct.scala @@ -1,6 +1,6 @@ package io.appthreat.console sealed trait JProduct: - def name: String + def name: String case object ChenProduct extends JProduct: - val name: String = "chen" + val name: String = "chen" diff --git a/console/src/main/scala/io/appthreat/console/PluginManager.scala b/console/src/main/scala/io/appthreat/console/PluginManager.scala index 1b475c6a..48769bfc 100644 --- a/console/src/main/scala/io/appthreat/console/PluginManager.scala +++ b/console/src/main/scala/io/appthreat/console/PluginManager.scala @@ -19,78 +19,78 @@ import scala.util.{Failure, Success, Try} */ class PluginManager(val installDir: File): - /** Generate a sorted list of all installed plugins by examining the plugin directory. - */ - def listPlugins(): List[String] = - val installedPluginNames = pluginDir.toList - .flatMap { dir => - File(dir).list.toList.flatMap { f => - "^joernext-(.*?)-.*$".r.findAllIn(f.name).matchData.map { m => - m.group(1) - } + /** Generate a sorted list of all installed plugins by examining the plugin directory. + */ + def listPlugins(): List[String] = + val installedPluginNames = pluginDir.toList + .flatMap { dir => + File(dir).list.toList.flatMap { f => + "^joernext-(.*?)-.*$".r.findAllIn(f.name).matchData.map { m => + m.group(1) } } - .distinct - .sorted - installedPluginNames + } + .distinct + .sorted + installedPluginNames - /** Install the plugin stored at `filename`. The plugin is expected to be a zip file containing - * Java archives (.jar files). - */ - def add(filename: String): Unit = - if pluginDir.isEmpty then - println("Plugin directory does not exist") - return - val file = File(filename) - if !file.exists then - println(s"The file $filename does not exist") - else - addExisting(file) + /** Install the plugin stored at `filename`. The plugin is expected to be a zip file containing + * Java archives (.jar files). + */ + def add(filename: String): Unit = + if pluginDir.isEmpty then + println("Plugin directory does not exist") + return + val file = File(filename) + if !file.exists then + println(s"The file $filename does not exist") + else + addExisting(file) - private def addExisting(file: File): Unit = - val pluginName = file.name.replace(".zip", "") - val tmpDir = extractToTemporaryDir(file) - tmpDir.foreach(dir => addExistingUnzipped(dir, pluginName)) + private def addExisting(file: File): Unit = + val pluginName = file.name.replace(".zip", "") + val tmpDir = extractToTemporaryDir(file) + tmpDir.foreach(dir => addExistingUnzipped(dir, pluginName)) - private def addExistingUnzipped(file: File, pluginName: String): Unit = - file.listRecursively.filter(_.name.endsWith(".jar")).foreach { jar => - pluginDir.foreach { pDir => - if !(pDir / jar.name).exists then - val dstFileName = s"joernext-$pluginName-${jar.name}" - val dstFile = pDir / dstFileName - cp(jar, dstFile) - } - } + private def addExistingUnzipped(file: File, pluginName: String): Unit = + file.listRecursively.filter(_.name.endsWith(".jar")).foreach { jar => + pluginDir.foreach { pDir => + if !(pDir / jar.name).exists then + val dstFileName = s"joernext-$pluginName-${jar.name}" + val dstFile = pDir / dstFileName + cp(jar, dstFile) + } + } - private def extractToTemporaryDir(file: File) = - Try { file.unzip() } match - case Success(dir) => - Some(dir) - case Failure(exc) => - println("Error reading zip: " + exc.getMessage) - None + private def extractToTemporaryDir(file: File) = + Try { file.unzip() } match + case Success(dir) => + Some(dir) + case Failure(exc) => + println("Error reading zip: " + exc.getMessage) + None - /** Delete plugin with given `name` from the plugin directory. - */ - def rm(name: String): List[String] = - if !listPlugins().contains(name) then - List() - else - val filesToRemove = pluginDir.toList.flatMap { dir => - dir.list.filter { f => - f.name.startsWith(s"joernext-$name") - } + /** Delete plugin with given `name` from the plugin directory. + */ + def rm(name: String): List[String] = + if !listPlugins().contains(name) then + List() + else + val filesToRemove = pluginDir.toList.flatMap { dir => + dir.list.filter { f => + f.name.startsWith(s"joernext-$name") } - filesToRemove.foreach(f => f.delete()) - filesToRemove.map(_.pathAsString) + } + filesToRemove.foreach(f => f.delete()) + filesToRemove.map(_.pathAsString) - /** Return the path to the plugin directory or None if the plugin directory does not exist. - */ - def pluginDir: Option[Path] = - val pathToPluginDir = installDir.path.resolve("lib") - if pathToPluginDir.toFile.exists() then - Some(pathToPluginDir) - else - println(s"Plugin directory at $pathToPluginDir does not exist") - None + /** Return the path to the plugin directory or None if the plugin directory does not exist. + */ + def pluginDir: Option[Path] = + val pathToPluginDir = installDir.path.resolve("lib") + if pathToPluginDir.toFile.exists() then + Some(pathToPluginDir) + else + println(s"Plugin directory at $pathToPluginDir does not exist") + None end PluginManager diff --git a/console/src/main/scala/io/appthreat/console/Reporting.scala b/console/src/main/scala/io/appthreat/console/Reporting.scala index 08a15d2a..7c474ed5 100644 --- a/console/src/main/scala/io/appthreat/console/Reporting.scala +++ b/console/src/main/scala/io/appthreat/console/Reporting.scala @@ -5,11 +5,11 @@ import scala.collection.mutable trait Reporting: - def reportOutStream: OutputStream = System.err + def reportOutStream: OutputStream = System.err - def report(string: String): Unit = - reportOutStream.write((string + "\n").getBytes("UTF-8")) - GlobalReporting.appendToGlobalStdOut(string) + def report(string: String): Unit = + reportOutStream.write((string + "\n").getBytes("UTF-8")) + GlobalReporting.appendToGlobalStdOut(string) /** A dirty hack to capture the reported output for the server-mode. Context: server mode is a bit * tricky, because the reporting happens inside the repl, but we want to retrieve it from the @@ -20,23 +20,23 @@ trait Reporting: * UserRunnables concurrently. */ object GlobalReporting: - private var enabled = false + private var enabled = false - def enable(): Unit = - enabled = true + def enable(): Unit = + enabled = true - def isEnabled(): Boolean = enabled + def isEnabled(): Boolean = enabled - def disable(): Unit = - enabled = false + def disable(): Unit = + enabled = false - private val globalStdOut = new mutable.StringBuilder + private val globalStdOut = new mutable.StringBuilder - def appendToGlobalStdOut(s: String): Unit = - if enabled then globalStdOut.append(s + System.lineSeparator()) + def appendToGlobalStdOut(s: String): Unit = + if enabled then globalStdOut.append(s + System.lineSeparator()) - def getAndClearGlobalStdOut(): String = - val result = globalStdOut.result() - globalStdOut.clear() - result + def getAndClearGlobalStdOut(): String = + val result = globalStdOut.result() + globalStdOut.clear() + result end GlobalReporting diff --git a/console/src/main/scala/io/appthreat/console/Run.scala b/console/src/main/scala/io/appthreat/console/Run.scala index cf95a7b1..0bdae7ff 100644 --- a/console/src/main/scala/io/appthreat/console/Run.scala +++ b/console/src/main/scala/io/appthreat/console/Run.scala @@ -10,54 +10,54 @@ import scala.jdk.CollectionConverters.* object Run: - def runCustomQuery(console: Console[?], query: HasStoreMethod): Unit = - console._runAnalyzer(new LayerCreator: - override val overlayName: String = "custom" - override val description: String = "A custom pass" + def runCustomQuery(console: Console[?], query: HasStoreMethod): Unit = + console._runAnalyzer(new LayerCreator: + override val overlayName: String = "custom" + override val description: String = "A custom pass" - override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = - val pass: CpgPass = new CpgPass(console.cpg): - override val name = "custom" - override def run(builder: DiffGraphBuilder): Unit = - query.store()(builder) - runPass(pass, context, storeUndoInfo) - ) + override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = + val pass: CpgPass = new CpgPass(console.cpg): + override val name = "custom" + override def run(builder: DiffGraphBuilder): Unit = + query.store()(builder) + runPass(pass, context, storeUndoInfo) + ) - /** Generate code for the run command - * @param exclude - * list of analyzers to exclude (by full class name) - */ - def codeForRunCommand(exclude: List[String] = List()): String = - val ioSlLayerCreators = creatorsFor("io.shiftleft", exclude) - val ioJoernLayerCreators = creatorsFor("io.appthreat", exclude) - codeForLayerCreators((ioSlLayerCreators ++ ioJoernLayerCreators).distinct) + /** Generate code for the run command + * @param exclude + * list of analyzers to exclude (by full class name) + */ + def codeForRunCommand(exclude: List[String] = List()): String = + val ioSlLayerCreators = creatorsFor("io.shiftleft", exclude) + val ioJoernLayerCreators = creatorsFor("io.appthreat", exclude) + codeForLayerCreators((ioSlLayerCreators ++ ioJoernLayerCreators).distinct) - private def creatorsFor(namespace: String, exclude: List[String]) = - new Reflections( - new ConfigurationBuilder().setUrls( - ClasspathHelper.forPackage( - namespace, - ClasspathHelper.contextClassLoader(), - ClasspathHelper.staticClassLoader() - ) + private def creatorsFor(namespace: String, exclude: List[String]) = + new Reflections( + new ConfigurationBuilder().setUrls( + ClasspathHelper.forPackage( + namespace, + ClasspathHelper.contextClassLoader(), + ClasspathHelper.staticClassLoader() + ) + ) + ).getSubTypesOf(classOf[LayerCreator]) + .asScala + .filterNot(t => + t.isAnonymousClass || t.isLocalClass || t.isMemberClass || t.isSynthetic ) - ).getSubTypesOf(classOf[LayerCreator]) - .asScala - .filterNot(t => - t.isAnonymousClass || t.isLocalClass || t.isMemberClass || t.isSynthetic - ) - .filterNot(t => t.getName.startsWith("io.appthreat.console.Run")) - .toList - .map(t => (t.getSimpleName.toLowerCase, s"_root_.${t.getName}")) - .filter(t => !exclude.contains(t._2)) + .filterNot(t => t.getName.startsWith("io.appthreat.console.Run")) + .toList + .map(t => (t.getSimpleName.toLowerCase, s"_root_.${t.getName}")) + .filter(t => !exclude.contains(t._2)) - private def codeForLayerCreators(layerCreatorTypeNames: List[(String, String)]): String = - val optsMembersCode = layerCreatorTypeNames - .map { case (varName, typeName) => s"val $varName = $typeName.defaultOpts" } - .mkString("\n") + private def codeForLayerCreators(layerCreatorTypeNames: List[(String, String)]): String = + val optsMembersCode = layerCreatorTypeNames + .map { case (varName, typeName) => s"val $varName = $typeName.defaultOpts" } + .mkString("\n") - val optsCode = - s""" + val optsCode = + s""" |class OptsDynamic { |$optsMembersCode |} @@ -69,27 +69,27 @@ object Run: | def diffGraph = _diffGraph |""".stripMargin - val membersCode = layerCreatorTypeNames - .map { case (varName, typeName) => - s" def $varName: Cpg = _runAnalyzer(new $typeName(opts.$varName))" - } - .mkString("\n") + val membersCode = layerCreatorTypeNames + .map { case (varName, typeName) => + s" def $varName: Cpg = _runAnalyzer(new $typeName(opts.$varName))" + } + .mkString("\n") - val toStringCode = - s""" + val toStringCode = + s""" | import overflowdb.traversal.help.Table | override def toString() : String = { | val columnNames = List("name", "description") | val rows = | ${layerCreatorTypeNames.map { case (varName, typeName) => - s"""List("$varName",$typeName.description.trim)""" - }} + s"""List("$varName",$typeName.description.trim)""" + }} | "\\n" + Table(columnNames, rows).render | } |""".stripMargin - optsCode + - s""" + optsCode + + s""" |class OverlaysDynamic { | | def apply(query: _root_.io.shiftleft.semanticcpg.language.HasStoreMethod) = @@ -101,5 +101,5 @@ object Run: |} |val run = new OverlaysDynamic() |""".stripMargin - end codeForLayerCreators + end codeForLayerCreators end Run diff --git a/console/src/main/scala/io/appthreat/console/cpgcreation/AtomGenerator.scala b/console/src/main/scala/io/appthreat/console/cpgcreation/AtomGenerator.scala index ed3246e9..98649c4e 100644 --- a/console/src/main/scala/io/appthreat/console/cpgcreation/AtomGenerator.scala +++ b/console/src/main/scala/io/appthreat/console/cpgcreation/AtomGenerator.scala @@ -14,43 +14,43 @@ case class AtomGenerator( sliceMode: String = "reachables", slicesFile: String = "reachables.slices.json" ) extends CpgGenerator: - private lazy val command: String = sys.env.getOrElse("ATOM_CMD", "atom") - private lazy val cdxgenCommand: String = sys.env.getOrElse("CDXGEN_CMD", "cdxgen") + private lazy val command: String = sys.env.getOrElse("ATOM_CMD", "atom") + private lazy val cdxgenCommand: String = sys.env.getOrElse("CDXGEN_CMD", "cdxgen") - /** Generate an atom for the given input path. Returns the output path, or None, if no CPG was - * generated. - */ - override def generate(inputPath: String, outputPath: String = "app.atom"): Try[String] = - // If there is no bom.json file in the root directory, attempt to automatically invoke cdxgen - val bomPath = File(inputPath) / "bom.json" - if !bomPath.exists then - val cdxLanguage = language.toLowerCase().replace("src", "") - val arguments = Seq( - "-t", - cdxLanguage, - "--deep", - "-o", - (File(inputPath) / "bom.json").pathAsString, - inputPath - ) - runShellCommand(cdxgenCommand, arguments) - val arguments = Seq( - sliceMode, - "-s", - (File(inputPath) / slicesFile).pathAsString, - "--output", - (File(inputPath) / outputPath).pathAsString, - "--language", - language, - inputPath - ) ++ config.cmdLineParams - runShellCommand(command, arguments).map(_ => (File(inputPath) / outputPath).pathAsString) - end generate + /** Generate an atom for the given input path. Returns the output path, or None, if no CPG was + * generated. + */ + override def generate(inputPath: String, outputPath: String = "app.atom"): Try[String] = + // If there is no bom.json file in the root directory, attempt to automatically invoke cdxgen + val bomPath = File(inputPath) / "bom.json" + if !bomPath.exists then + val cdxLanguage = language.toLowerCase().replace("src", "") + val arguments = Seq( + "-t", + cdxLanguage, + "--deep", + "-o", + (File(inputPath) / "bom.json").pathAsString, + inputPath + ) + runShellCommand(cdxgenCommand, arguments) + val arguments = Seq( + sliceMode, + "-s", + (File(inputPath) / slicesFile).pathAsString, + "--output", + (File(inputPath) / outputPath).pathAsString, + "--language", + language, + inputPath + ) ++ config.cmdLineParams + runShellCommand(command, arguments).map(_ => (File(inputPath) / outputPath).pathAsString) + end generate - override def isAvailable: Boolean = true + override def isAvailable: Boolean = true - override def applyPostProcessingPasses(atom: Cpg): Cpg = - atom + override def applyPostProcessingPasses(atom: Cpg): Cpg = + atom - override def isJvmBased = false + override def isJvmBased = false end AtomGenerator diff --git a/console/src/main/scala/io/appthreat/console/cpgcreation/CCpgGenerator.scala b/console/src/main/scala/io/appthreat/console/cpgcreation/CCpgGenerator.scala index 10de20c9..d535e1a5 100644 --- a/console/src/main/scala/io/appthreat/console/cpgcreation/CCpgGenerator.scala +++ b/console/src/main/scala/io/appthreat/console/cpgcreation/CCpgGenerator.scala @@ -9,17 +9,17 @@ import scala.util.Try * Eclipse CDT parsing / preprocessing. */ case class CCpgGenerator(config: FrontendConfig, rootPath: Path) extends CpgGenerator: - private lazy val command: Path = - if isWin then rootPath.resolve("c2cpg.bat") else rootPath.resolve("c2cpg.sh") + private lazy val command: Path = + if isWin then rootPath.resolve("c2cpg.bat") else rootPath.resolve("c2cpg.sh") - /** Generate a CPG for the given input path. Returns the output path, or None, if no CPG was - * generated. - */ - override def generate(inputPath: String, outputPath: String = "cpg.bin"): Try[String] = - val arguments = Seq(inputPath, "--output", outputPath) ++ config.cmdLineParams - runShellCommand(command.toString, arguments).map(_ => outputPath) + /** Generate a CPG for the given input path. Returns the output path, or None, if no CPG was + * generated. + */ + override def generate(inputPath: String, outputPath: String = "cpg.bin"): Try[String] = + val arguments = Seq(inputPath, "--output", outputPath) ++ config.cmdLineParams + runShellCommand(command.toString, arguments).map(_ => outputPath) - override def isAvailable: Boolean = - command.toFile.exists + override def isAvailable: Boolean = + command.toFile.exists - override def isJvmBased = true + override def isJvmBased = true diff --git a/console/src/main/scala/io/appthreat/console/cpgcreation/CdxGenerator.scala b/console/src/main/scala/io/appthreat/console/cpgcreation/CdxGenerator.scala index c1b4994f..fbe60ad3 100644 --- a/console/src/main/scala/io/appthreat/console/cpgcreation/CdxGenerator.scala +++ b/console/src/main/scala/io/appthreat/console/cpgcreation/CdxGenerator.scala @@ -10,24 +10,24 @@ import better.files.File.LinkOptions case class CdxGenerator(config: FrontendConfig, rootPath: Path, language: String) extends CpgGenerator: - private lazy val command: String = "cdxgen" + private lazy val command: String = "cdxgen" - /** Generate a CycloneDX BoM for the given input path. Returns the output path, or None, if no - * cdx was generated. - */ - override def generate(inputPath: String, outputPath: String = "bom.json"): Try[String] = - var outFile = - if File(outputPath).isDirectory(linkOptions = LinkOptions.noFollow) then - (File(outputPath) / "bom.json").pathAsString - else outputPath - val arguments = - Seq("-o", outFile, "-t", language, "--deep", inputPath) ++ config.cmdLineParams - runShellCommand(command, arguments).map(_ => outFile) + /** Generate a CycloneDX BoM for the given input path. Returns the output path, or None, if no cdx + * was generated. + */ + override def generate(inputPath: String, outputPath: String = "bom.json"): Try[String] = + var outFile = + if File(outputPath).isDirectory(linkOptions = LinkOptions.noFollow) then + (File(outputPath) / "bom.json").pathAsString + else outputPath + val arguments = + Seq("-o", outFile, "-t", language, "--deep", inputPath) ++ config.cmdLineParams + runShellCommand(command, arguments).map(_ => outFile) - override def isAvailable: Boolean = true + override def isAvailable: Boolean = true - override def applyPostProcessingPasses(atom: Cpg): Cpg = - atom + override def applyPostProcessingPasses(atom: Cpg): Cpg = + atom - override def isJvmBased = false + override def isJvmBased = false end CdxGenerator diff --git a/console/src/main/scala/io/appthreat/console/cpgcreation/CpgGenerator.scala b/console/src/main/scala/io/appthreat/console/cpgcreation/CpgGenerator.scala index caaef721..e0c6b8f1 100644 --- a/console/src/main/scala/io/appthreat/console/cpgcreation/CpgGenerator.scala +++ b/console/src/main/scala/io/appthreat/console/cpgcreation/CpgGenerator.scala @@ -12,36 +12,36 @@ import scala.util.Try */ abstract class CpgGenerator(): - def isWin: Boolean = scala.util.Properties.isWin - - def isAvailable: Boolean - - /** is this a JVM based frontend? if so, we'll invoke it with -Xmx for max heap settings */ - def isJvmBased: Boolean - - /** Generate a CPG for the given input path. Returns the output path, or a Failure, if no CPG - * was generated. - * - * This method appends command line options in config.frontend.cmdLineParams to the shell - * command. - */ - def generate(inputPath: String, outputPath: String = "app.atom"): Try[String] - - protected def runShellCommand(program: String, arguments: Seq[String]): Try[Unit] = - Try { - val cmd = Seq(program) ++ performanceParameter ++ arguments - val exitValue = cmd.run().exitValue() - assert(exitValue == 0, s"Error running shell command: exitValue=$exitValue; $cmd") - } - - protected lazy val performanceParameter = - if isJvmBased then - val maxValueInGigabytes = - Math.floor(Runtime.getRuntime.maxMemory.toDouble / 1024 / 1024 / 1024).toInt - Seq(s"-J-Xmx${maxValueInGigabytes}G") - else Nil - - /** override in specific cpg generators to make them apply post processing passes */ - def applyPostProcessingPasses(cpg: Cpg): Cpg = - cpg + def isWin: Boolean = scala.util.Properties.isWin + + def isAvailable: Boolean + + /** is this a JVM based frontend? if so, we'll invoke it with -Xmx for max heap settings */ + def isJvmBased: Boolean + + /** Generate a CPG for the given input path. Returns the output path, or a Failure, if no CPG was + * generated. + * + * This method appends command line options in config.frontend.cmdLineParams to the shell + * command. + */ + def generate(inputPath: String, outputPath: String = "app.atom"): Try[String] + + protected def runShellCommand(program: String, arguments: Seq[String]): Try[Unit] = + Try { + val cmd = Seq(program) ++ performanceParameter ++ arguments + val exitValue = cmd.run().exitValue() + assert(exitValue == 0, s"Error running shell command: exitValue=$exitValue; $cmd") + } + + protected lazy val performanceParameter = + if isJvmBased then + val maxValueInGigabytes = + Math.floor(Runtime.getRuntime.maxMemory.toDouble / 1024 / 1024 / 1024).toInt + Seq(s"-J-Xmx${maxValueInGigabytes}G") + else Nil + + /** override in specific cpg generators to make them apply post processing passes */ + def applyPostProcessingPasses(cpg: Cpg): Cpg = + cpg end CpgGenerator diff --git a/console/src/main/scala/io/appthreat/console/cpgcreation/CpgGeneratorFactory.scala b/console/src/main/scala/io/appthreat/console/cpgcreation/CpgGeneratorFactory.scala index 86b9866a..58e45f77 100644 --- a/console/src/main/scala/io/appthreat/console/cpgcreation/CpgGeneratorFactory.scala +++ b/console/src/main/scala/io/appthreat/console/cpgcreation/CpgGeneratorFactory.scala @@ -11,84 +11,84 @@ import java.nio.file.Path import scala.util.Try object CpgGeneratorFactory: - private val KNOWN_LANGUAGES = Set( - Languages.C, - Languages.CSHARP, - Languages.GOLANG, - Languages.GHIDRA, - Languages.JAVA, - Languages.JAVASCRIPT, - Languages.JSSRC, - Languages.PYTHON, - Languages.PYTHONSRC, - Languages.LLVM, - Languages.PHP, - Languages.KOTLIN, - Languages.NEWC, - Languages.JAVASRC - ) + private val KNOWN_LANGUAGES = Set( + Languages.C, + Languages.CSHARP, + Languages.GOLANG, + Languages.GHIDRA, + Languages.JAVA, + Languages.JAVASCRIPT, + Languages.JSSRC, + Languages.PYTHON, + Languages.PYTHONSRC, + Languages.LLVM, + Languages.PHP, + Languages.KOTLIN, + Languages.NEWC, + Languages.JAVASRC + ) class CpgGeneratorFactory(config: ConsoleConfig): - /** For a given input path, try to guess a suitable generator and return it - */ - def forCodeAt(inputPath: String): Option[CpgGenerator] = - for - language <- guessLanguage(inputPath) - cpgGenerator <- cpgGeneratorForLanguage( - language, - config.frontend, - config.install.rootPath.path, - args = Nil - ) - yield cpgGenerator + /** For a given input path, try to guess a suitable generator and return it + */ + def forCodeAt(inputPath: String): Option[CpgGenerator] = + for + language <- guessLanguage(inputPath) + cpgGenerator <- cpgGeneratorForLanguage( + language, + config.frontend, + config.install.rootPath.path, + args = Nil + ) + yield cpgGenerator - /** For a language, return the generator - */ - def forLanguage(language: String): Option[CpgGenerator] = - Option(language.toUpperCase()) - .filter(languageIsKnown) - .flatMap { lang => - cpgGeneratorForLanguage( - lang, - config.frontend, - config.install.rootPath.path, - args = Nil - ) - } + /** For a language, return the generator + */ + def forLanguage(language: String): Option[CpgGenerator] = + Option(language.toUpperCase()) + .filter(languageIsKnown) + .flatMap { lang => + cpgGeneratorForLanguage( + lang, + config.frontend, + config.install.rootPath.path, + args = Nil + ) + } - def languageIsKnown(language: String): Boolean = - CpgGeneratorFactory.KNOWN_LANGUAGES.contains(language) + def languageIsKnown(language: String): Boolean = + CpgGeneratorFactory.KNOWN_LANGUAGES.contains(language) - def runGenerator(generator: CpgGenerator, inputPath: String, outputPath: String): Try[Path] = - val outputFileOpt: Try[File] = - generator.generate(inputPath, outputPath).map(File(_)) - outputFileOpt.map { outFile => - val parentPath = outFile.parent.path.toAbsolutePath - if isZipFile(outFile) then - val srcFilename = outFile.path.toAbsolutePath.toString - val dstFilename = parentPath.resolve("cpg.bin").toAbsolutePath.toString - // MemoryHelper.hintForInsufficientMemory(srcFilename).map(report) - convertProtoCpgToOverflowDb(srcFilename, dstFilename) - else - val srcPath = parentPath.resolve("app.atom") - if srcPath.toFile.exists() then - mv(srcPath, parentPath.resolve("cpg.bin")) - parentPath - } + def runGenerator(generator: CpgGenerator, inputPath: String, outputPath: String): Try[Path] = + val outputFileOpt: Try[File] = + generator.generate(inputPath, outputPath).map(File(_)) + outputFileOpt.map { outFile => + val parentPath = outFile.parent.path.toAbsolutePath + if isZipFile(outFile) then + val srcFilename = outFile.path.toAbsolutePath.toString + val dstFilename = parentPath.resolve("cpg.bin").toAbsolutePath.toString + // MemoryHelper.hintForInsufficientMemory(srcFilename).map(report) + convertProtoCpgToOverflowDb(srcFilename, dstFilename) + else + val srcPath = parentPath.resolve("app.atom") + if srcPath.toFile.exists() then + mv(srcPath, parentPath.resolve("cpg.bin")) + parentPath + } - def convertProtoCpgToOverflowDb(srcFilename: String, dstFilename: String): Unit = - val odbConfig = Config.withDefaults.withStorageLocation(dstFilename) - val config = - CpgLoaderConfig.withDefaults.doNotCreateIndexesOnLoad.withOverflowConfig(odbConfig) - CpgLoader.load(srcFilename, config).close - File(srcFilename).delete() + def convertProtoCpgToOverflowDb(srcFilename: String, dstFilename: String): Unit = + val odbConfig = Config.withDefaults.withStorageLocation(dstFilename) + val config = + CpgLoaderConfig.withDefaults.doNotCreateIndexesOnLoad.withOverflowConfig(odbConfig) + CpgLoader.load(srcFilename, config).close + File(srcFilename).delete() - def isZipFile(file: File): Boolean = - val bytes = file.bytes - Try { - bytes.next() == 'P' && bytes.next() == 'K' - }.getOrElse(false) + def isZipFile(file: File): Boolean = + val bytes = file.bytes + Try { + bytes.next() == 'P' && bytes.next() == 'K' + }.getOrElse(false) - private def report(str: String): Unit = System.err.println(str) + private def report(str: String): Unit = System.err.println(str) end CpgGeneratorFactory diff --git a/console/src/main/scala/io/appthreat/console/cpgcreation/ImportCode.scala b/console/src/main/scala/io/appthreat/console/cpgcreation/ImportCode.scala index 008ba131..1f802e12 100644 --- a/console/src/main/scala/io/appthreat/console/cpgcreation/ImportCode.scala +++ b/console/src/main/scala/io/appthreat/console/cpgcreation/ImportCode.scala @@ -12,214 +12,214 @@ import java.nio.file.Path import scala.util.{Failure, Success, Try} class ImportCode[T <: Project](console: Console[T]) extends Reporting: - import Console.* - - private val config = console.config - private val workspace = console.workspace - protected val generatorFactory = new CpgGeneratorFactory(config) - val chenPyUtils = py.module("chenpy.utils") - - private def checkInputPath(inputPath: String): Unit = - if !File(inputPath).exists then - throw new ConsoleException(s"Input path does not exist: '$inputPath'") - - def importUrl(inputPath: String): String = chenPyUtils.import_url(inputPath).as[String] - - /** This is the `importCode(...)` method exposed on the console. It attempts to find a suitable - * CPG generator first by looking at the `language` parameter and if no generator is found for - * the language, looking the contents at `inputPath` to determine heuristically which generator - * to use. - */ - def apply(inputPath: String, projectName: String = "", language: String = ""): Cpg = - var srcPath = - if - inputPath.startsWith("http") || inputPath.startsWith( - "git://" - ) || inputPath.startsWith("CVE-") || inputPath - .startsWith("GHSA-") - then importUrl(inputPath) - else inputPath - checkInputPath(srcPath) - if language != "" then - generatorFactory.forLanguage(language.toUpperCase()) match - case None => - throw new ConsoleException(s"No Atom generator exists for language: $language") - case Some(frontend) => apply(frontend, srcPath, projectName) - else - generatorFactory.forCodeAt(srcPath) match - case None => - throw new ConsoleException(s"No suitable Atom generator found for: $srcPath") - case Some(frontend) => apply(frontend, srcPath, projectName) - end apply - - def c: SourceBasedFrontend = new CFrontend("c") - def cpp: SourceBasedFrontend = new CFrontend("cpp", extension = "cpp") - def java: SourceBasedFrontend = - new SourceBasedFrontend("java", Languages.JAVASRC, "Java Source Frontend", "java") - def jvm: Frontend = - new BinaryFrontend( - "jvm", - Languages.JAVA, - "Java/Dalvik Bytecode Frontend (based on SOOT's jimple)" - ) - def ghidra: Frontend = - new BinaryFrontend("ghidra", Languages.GHIDRA, "ghidra reverse engineering frontend") - def kotlin: SourceBasedFrontend = - new SourceBasedFrontend("kotlin", Languages.KOTLIN, "Kotlin Source Frontend", "kotlin") - def python: SourceBasedFrontend = - new SourceBasedFrontend("python", Languages.PYTHONSRC, "Python Source Frontend", "py") - def golang: SourceBasedFrontend = - new SourceBasedFrontend("golang", Languages.GOLANG, "Golang Source Frontend", "go") - def javascript: SourceBasedFrontend = - new SourceBasedFrontend( - "javascript", - Languages.JAVASCRIPT, - "Javascript Source Frontend", - "js" - ) - def jssrc: SourceBasedFrontend = - new SourceBasedFrontend( - "jssrc", - Languages.JSSRC, - "Javascript/Typescript Source Frontend based on astgen", - "js" - ) - def csharp: Frontend = - new BinaryFrontend("csharp", Languages.CSHARP, "C# Source Frontend (Roslyn)") - def llvm: Frontend = new BinaryFrontend("llvm", Languages.LLVM, "LLVM Bitcode Frontend") - def php: SourceBasedFrontend = - new SourceBasedFrontend("php", Languages.PHP, "PHP source frontend", "php") - def ruby: SourceBasedFrontend = - new SourceBasedFrontend("ruby", Languages.RUBYSRC, "Ruby source frontend", "rb") - - private def allFrontends: List[Frontend] = - List( - c, - cpp, - ghidra, - kotlin, - java, - jvm, - javascript, - jssrc, - golang, - llvm, - php, - python, - csharp, - ruby + import Console.* + + private val config = console.config + private val workspace = console.workspace + protected val generatorFactory = new CpgGeneratorFactory(config) + val chenPyUtils = py.module("chenpy.utils") + + private def checkInputPath(inputPath: String): Unit = + if !File(inputPath).exists then + throw new ConsoleException(s"Input path does not exist: '$inputPath'") + + def importUrl(inputPath: String): String = chenPyUtils.import_url(inputPath).as[String] + + /** This is the `importCode(...)` method exposed on the console. It attempts to find a suitable + * CPG generator first by looking at the `language` parameter and if no generator is found for + * the language, looking the contents at `inputPath` to determine heuristically which generator + * to use. + */ + def apply(inputPath: String, projectName: String = "", language: String = ""): Cpg = + var srcPath = + if + inputPath.startsWith("http") || inputPath.startsWith( + "git://" + ) || inputPath.startsWith("CVE-") || inputPath + .startsWith("GHSA-") + then importUrl(inputPath) + else inputPath + checkInputPath(srcPath) + if language != "" then + generatorFactory.forLanguage(language.toUpperCase()) match + case None => + throw new ConsoleException(s"No Atom generator exists for language: $language") + case Some(frontend) => apply(frontend, srcPath, projectName) + else + generatorFactory.forCodeAt(srcPath) match + case None => + throw new ConsoleException(s"No suitable Atom generator found for: $srcPath") + case Some(frontend) => apply(frontend, srcPath, projectName) + end apply + + def c: SourceBasedFrontend = new CFrontend("c") + def cpp: SourceBasedFrontend = new CFrontend("cpp", extension = "cpp") + def java: SourceBasedFrontend = + new SourceBasedFrontend("java", Languages.JAVASRC, "Java Source Frontend", "java") + def jvm: Frontend = + new BinaryFrontend( + "jvm", + Languages.JAVA, + "Java/Dalvik Bytecode Frontend (based on SOOT's jimple)" + ) + def ghidra: Frontend = + new BinaryFrontend("ghidra", Languages.GHIDRA, "ghidra reverse engineering frontend") + def kotlin: SourceBasedFrontend = + new SourceBasedFrontend("kotlin", Languages.KOTLIN, "Kotlin Source Frontend", "kotlin") + def python: SourceBasedFrontend = + new SourceBasedFrontend("python", Languages.PYTHONSRC, "Python Source Frontend", "py") + def golang: SourceBasedFrontend = + new SourceBasedFrontend("golang", Languages.GOLANG, "Golang Source Frontend", "go") + def javascript: SourceBasedFrontend = + new SourceBasedFrontend( + "javascript", + Languages.JAVASCRIPT, + "Javascript Source Frontend", + "js" + ) + def jssrc: SourceBasedFrontend = + new SourceBasedFrontend( + "jssrc", + Languages.JSSRC, + "Javascript/Typescript Source Frontend based on astgen", + "js" + ) + def csharp: Frontend = + new BinaryFrontend("csharp", Languages.CSHARP, "C# Source Frontend (Roslyn)") + def llvm: Frontend = new BinaryFrontend("llvm", Languages.LLVM, "LLVM Bitcode Frontend") + def php: SourceBasedFrontend = + new SourceBasedFrontend("php", Languages.PHP, "PHP source frontend", "php") + def ruby: SourceBasedFrontend = + new SourceBasedFrontend("ruby", Languages.RUBYSRC, "Ruby source frontend", "rb") + + private def allFrontends: List[Frontend] = + List( + c, + cpp, + ghidra, + kotlin, + java, + jvm, + javascript, + jssrc, + golang, + llvm, + php, + python, + csharp, + ruby + ) + + // this is only abstract to force people adding frontends to make a decision whether the frontend consumes binaries or source + abstract class Frontend(val name: String, val language: String, val description: String = ""): + def cpgGeneratorForLanguage( + language: String, + config: FrontendConfig, + rootPath: Path, + args: List[String] + ): Option[CpgGenerator] = + io.appthreat.console.cpgcreation.cpgGeneratorForLanguage( + language, + config, + rootPath, + args ) - // this is only abstract to force people adding frontends to make a decision whether the frontend consumes binaries or source - abstract class Frontend(val name: String, val language: String, val description: String = ""): - def cpgGeneratorForLanguage( - language: String, - config: FrontendConfig, - rootPath: Path, - args: List[String] - ): Option[CpgGenerator] = - io.appthreat.console.cpgcreation.cpgGeneratorForLanguage( - language, - config, - rootPath, - args - ) - - def isAvailable: Boolean = - cpgGeneratorForLanguage( - language, - config.frontend, - config.install.rootPath.path, - args = Nil - ).exists(_.isAvailable) - - def apply(inputPath: String, projectName: String = "", args: List[String] = List()): Cpg = - val frontend = cpgGeneratorForLanguage( - language, - config.frontend, - config.install.rootPath.path, - args + def isAvailable: Boolean = + cpgGeneratorForLanguage( + language, + config.frontend, + config.install.rootPath.path, + args = Nil + ).exists(_.isAvailable) + + def apply(inputPath: String, projectName: String = "", args: List[String] = List()): Cpg = + val frontend = cpgGeneratorForLanguage( + language, + config.frontend, + config.install.rootPath.path, + args + ) + .getOrElse(throw new ConsoleException( + s"no atom generator for language=$language available!" + )) + new ImportCode(console)(frontend, inputPath, projectName) + end Frontend + + private class BinaryFrontend(name: String, language: String, description: String = "") + extends Frontend(name, language, description) + + class SourceBasedFrontend( + name: String, + language: String, + description: String, + extension: String + ) extends Frontend(name, language, description): + + def fromString(str: String, args: List[String] = List()): Cpg = + withCodeInTmpFile(str, "tmp." + extension) { dir => + super.apply(dir.path.toString, args = args) + } match + case Failure(exception) => throw new ConsoleException( + s"unable to generate atom from given String", + exception + ) + case Success(value) => value + class CFrontend(name: String, extension: String = "c") + extends SourceBasedFrontend( + name, + Languages.NEWC, + "Eclipse CDT Based Frontend for C/C++", + extension + ) + + private def withCodeInTmpFile(str: String, filename: String)(f: File => Cpg): Try[Cpg] = + val dir = File.newTemporaryDirectory("console") + val result = Try { + (dir / filename).write(str) + f(dir) + } + dir.deleteOnExit(swallowIOExceptions = true) + result + + /** Provide an overview of the available CPG generators (frontends) + */ + override def toString: String = + val cols = List("name", "description", "available") + val rows = allFrontends.map { frontend => + List(frontend.name, frontend.description, frontend.isAvailable.toString) + } + "Type `importCode.` to run a specific language frontend\n" + + "\n" + Table(cols, rows).render + + private def apply(generator: CpgGenerator, inputPath: String, projectName: String): Cpg = + checkInputPath(inputPath) + + val name = Option(projectName).filter(_.nonEmpty).getOrElse(deriveNameFromInputPath( + inputPath, + workspace + )) + report(s"Creating project `$name` for code at `$inputPath`") + + val cpgMaybe = workspace.createProject(inputPath, name).flatMap { pathToProject => + val frontendCpgOutFile = pathToProject.resolve(nameOfLegacyCpgInProject) + val frontendAtomPath = + generatorFactory.runGenerator(generator, inputPath, frontendCpgOutFile.toString) + frontendAtomPath match + case Success(_) => + console.open(name).flatMap(_.cpg) + case Failure(exception) => + throw new ConsoleException( + s"Error creating project for input path: `$inputPath`", + exception ) - .getOrElse(throw new ConsoleException( - s"no atom generator for language=$language available!" - )) - new ImportCode(console)(frontend, inputPath, projectName) - end Frontend - - private class BinaryFrontend(name: String, language: String, description: String = "") - extends Frontend(name, language, description) + } - class SourceBasedFrontend( - name: String, - language: String, - description: String, - extension: String - ) extends Frontend(name, language, description): - - def fromString(str: String, args: List[String] = List()): Cpg = - withCodeInTmpFile(str, "tmp." + extension) { dir => - super.apply(dir.path.toString, args = args) - } match - case Failure(exception) => throw new ConsoleException( - s"unable to generate atom from given String", - exception - ) - case Success(value) => value - class CFrontend(name: String, extension: String = "c") - extends SourceBasedFrontend( - name, - Languages.NEWC, - "Eclipse CDT Based Frontend for C/C++", - extension + cpgMaybe + .map(cpg => console.summary) + .getOrElse( + throw new ConsoleException(s"Error creating project for input path: `$inputPath`") ) - - private def withCodeInTmpFile(str: String, filename: String)(f: File => Cpg): Try[Cpg] = - val dir = File.newTemporaryDirectory("console") - val result = Try { - (dir / filename).write(str) - f(dir) - } - dir.deleteOnExit(swallowIOExceptions = true) - result - - /** Provide an overview of the available CPG generators (frontends) - */ - override def toString: String = - val cols = List("name", "description", "available") - val rows = allFrontends.map { frontend => - List(frontend.name, frontend.description, frontend.isAvailable.toString) - } - "Type `importCode.` to run a specific language frontend\n" + - "\n" + Table(cols, rows).render - - private def apply(generator: CpgGenerator, inputPath: String, projectName: String): Cpg = - checkInputPath(inputPath) - - val name = Option(projectName).filter(_.nonEmpty).getOrElse(deriveNameFromInputPath( - inputPath, - workspace - )) - report(s"Creating project `$name` for code at `$inputPath`") - - val cpgMaybe = workspace.createProject(inputPath, name).flatMap { pathToProject => - val frontendCpgOutFile = pathToProject.resolve(nameOfLegacyCpgInProject) - val frontendAtomPath = - generatorFactory.runGenerator(generator, inputPath, frontendCpgOutFile.toString) - frontendAtomPath match - case Success(_) => - console.open(name).flatMap(_.cpg) - case Failure(exception) => - throw new ConsoleException( - s"Error creating project for input path: `$inputPath`", - exception - ) - } - - cpgMaybe - .map(cpg => console.summary) - .getOrElse( - throw new ConsoleException(s"Error creating project for input path: `$inputPath`") - ) - cpgMaybe.get - end apply + cpgMaybe.get + end apply end ImportCode diff --git a/console/src/main/scala/io/appthreat/console/cpgcreation/JavaSrcCpgGenerator.scala b/console/src/main/scala/io/appthreat/console/cpgcreation/JavaSrcCpgGenerator.scala index afe4b098..6159897c 100644 --- a/console/src/main/scala/io/appthreat/console/cpgcreation/JavaSrcCpgGenerator.scala +++ b/console/src/main/scala/io/appthreat/console/cpgcreation/JavaSrcCpgGenerator.scala @@ -12,25 +12,25 @@ import scala.util.Try /** Source-based front-end for Java */ case class JavaSrcCpgGenerator(config: FrontendConfig, rootPath: Path) extends CpgGenerator: - private lazy val command: Path = - if isWin then rootPath.resolve("javasrc2cpg.bat") else rootPath.resolve("javasrc2cpg") - private var javaConfig: Option[Config] = None + private lazy val command: Path = + if isWin then rootPath.resolve("javasrc2cpg.bat") else rootPath.resolve("javasrc2cpg") + private var javaConfig: Option[Config] = None - /** Generate a CPG for the given input path. Returns the output path, or None, if no CPG was - * generated. - */ - override def generate(inputPath: String, outputPath: String = "cpg.bin"): Try[String] = - val arguments = config.cmdLineParams.toSeq ++ Seq(inputPath, "--output", outputPath) - javaConfig = X2Cpg.parseCommandLine(arguments.toArray, Main.getCmdLineParser, Config()) - runShellCommand(command.toString, arguments).map(_ => outputPath) + /** Generate a CPG for the given input path. Returns the output path, or None, if no CPG was + * generated. + */ + override def generate(inputPath: String, outputPath: String = "cpg.bin"): Try[String] = + val arguments = config.cmdLineParams.toSeq ++ Seq(inputPath, "--output", outputPath) + javaConfig = X2Cpg.parseCommandLine(arguments.toArray, Main.getCmdLineParser, Config()) + runShellCommand(command.toString, arguments).map(_ => outputPath) - override def applyPostProcessingPasses(cpg: Cpg): Cpg = - if javaConfig.forall(_.enableTypeRecovery) then - JavaSrc2Cpg.typeRecoveryPasses(cpg, javaConfig).foreach(_.createAndApply()) - super.applyPostProcessingPasses(cpg) + override def applyPostProcessingPasses(cpg: Cpg): Cpg = + if javaConfig.forall(_.enableTypeRecovery) then + JavaSrc2Cpg.typeRecoveryPasses(cpg, javaConfig).foreach(_.createAndApply()) + super.applyPostProcessingPasses(cpg) - override def isAvailable: Boolean = - command.toFile.exists + override def isAvailable: Boolean = + command.toFile.exists - override def isJvmBased = true + override def isJvmBased = true end JavaSrcCpgGenerator diff --git a/console/src/main/scala/io/appthreat/console/cpgcreation/JsCpgGenerator.scala b/console/src/main/scala/io/appthreat/console/cpgcreation/JsCpgGenerator.scala index 3ddc73f2..ebf7abaa 100644 --- a/console/src/main/scala/io/appthreat/console/cpgcreation/JsCpgGenerator.scala +++ b/console/src/main/scala/io/appthreat/console/cpgcreation/JsCpgGenerator.scala @@ -7,17 +7,17 @@ import java.nio.file.Path import scala.util.Try case class JsCpgGenerator(config: FrontendConfig, rootPath: Path) extends CpgGenerator: - private lazy val command: Path = - if isWin then rootPath.resolve("js2cpg.bat") else rootPath.resolve("js2cpg.sh") + private lazy val command: Path = + if isWin then rootPath.resolve("js2cpg.bat") else rootPath.resolve("js2cpg.sh") - /** Generate a CPG for the given input path. Returns the output path, or None, if no CPG was - * generated. - */ - override def generate(inputPath: String, outputPath: String = "app.atom"): Try[String] = - val arguments = Seq(inputPath, "--output", outputPath) ++ config.cmdLineParams - runShellCommand(command.toString, arguments).map(_ => outputPath) + /** Generate a CPG for the given input path. Returns the output path, or None, if no CPG was + * generated. + */ + override def generate(inputPath: String, outputPath: String = "app.atom"): Try[String] = + val arguments = Seq(inputPath, "--output", outputPath) ++ config.cmdLineParams + runShellCommand(command.toString, arguments).map(_ => outputPath) - override def isAvailable: Boolean = - command.toFile.exists + override def isAvailable: Boolean = + command.toFile.exists - override def isJvmBased = true + override def isJvmBased = true diff --git a/console/src/main/scala/io/appthreat/console/cpgcreation/JsSrcCpgGenerator.scala b/console/src/main/scala/io/appthreat/console/cpgcreation/JsSrcCpgGenerator.scala index fab6b930..09b0a3dd 100644 --- a/console/src/main/scala/io/appthreat/console/cpgcreation/JsSrcCpgGenerator.scala +++ b/console/src/main/scala/io/appthreat/console/cpgcreation/JsSrcCpgGenerator.scala @@ -10,24 +10,24 @@ import java.nio.file.Path import scala.util.Try case class JsSrcCpgGenerator(config: FrontendConfig, rootPath: Path) extends CpgGenerator: - private lazy val command: Path = - if isWin then rootPath.resolve("jssrc2cpg.bat") else rootPath.resolve("jssrc2cpg.sh") - private var jsConfig: Option[Config] = None + private lazy val command: Path = + if isWin then rootPath.resolve("jssrc2cpg.bat") else rootPath.resolve("jssrc2cpg.sh") + private var jsConfig: Option[Config] = None - /** Generate a CPG for the given input path. Returns the output path, or None, if no CPG was - * generated. - */ - override def generate(inputPath: String, outputPath: String = "app.atom"): Try[String] = - val arguments = Seq(inputPath, "--output", outputPath) ++ config.cmdLineParams - jsConfig = X2Cpg.parseCommandLine(arguments.toArray, Frontend.cmdLineParser, Config()) - runShellCommand(command.toString, arguments).map(_ => outputPath) + /** Generate a CPG for the given input path. Returns the output path, or None, if no CPG was + * generated. + */ + override def generate(inputPath: String, outputPath: String = "app.atom"): Try[String] = + val arguments = Seq(inputPath, "--output", outputPath) ++ config.cmdLineParams + jsConfig = X2Cpg.parseCommandLine(arguments.toArray, Frontend.cmdLineParser, Config()) + runShellCommand(command.toString, arguments).map(_ => outputPath) - override def isAvailable: Boolean = - command.toFile.exists + override def isAvailable: Boolean = + command.toFile.exists - override def applyPostProcessingPasses(cpg: Cpg): Cpg = - JsSrc2Cpg.postProcessingPasses(cpg, jsConfig).foreach(_.createAndApply()) - cpg + override def applyPostProcessingPasses(cpg: Cpg): Cpg = + JsSrc2Cpg.postProcessingPasses(cpg, jsConfig).foreach(_.createAndApply()) + cpg - override def isJvmBased = true + override def isJvmBased = true end JsSrcCpgGenerator diff --git a/console/src/main/scala/io/appthreat/console/cpgcreation/PythonSrcCpgGenerator.scala b/console/src/main/scala/io/appthreat/console/cpgcreation/PythonSrcCpgGenerator.scala index 05579d7a..846bd64a 100644 --- a/console/src/main/scala/io/appthreat/console/cpgcreation/PythonSrcCpgGenerator.scala +++ b/console/src/main/scala/io/appthreat/console/cpgcreation/PythonSrcCpgGenerator.scala @@ -11,42 +11,42 @@ import java.nio.file.Path import scala.util.Try case class PythonSrcCpgGenerator(config: FrontendConfig, rootPath: Path) extends CpgGenerator: - private lazy val command: Path = - if isWin then rootPath.resolve("pysrc2cpg.bat") else rootPath.resolve("pysrc2cpg") - private var pyConfig: Option[Py2CpgOnFileSystemConfig] = None - - /** Generate a CPG for the given input path. Returns the output path, or None, if no CPG was - * generated. - */ - override def generate(inputPath: String, outputPath: String = "app.atom"): Try[String] = - val arguments = Seq(inputPath, "-o", outputPath) ++ config.cmdLineParams - pyConfig = X2Cpg.parseCommandLine( - arguments.toArray, - NewMain.getCmdLineParser, - Py2CpgOnFileSystemConfig() - ) - runShellCommand(command.toString, arguments).map(_ => outputPath) - - override def isAvailable: Boolean = - command.toFile.exists - - override def applyPostProcessingPasses(cpg: Cpg): Cpg = - new ImportsPass(cpg).createAndApply() - new ImportResolverPass(cpg).createAndApply() - new DynamicTypeHintFullNamePass(cpg).createAndApply() - new PythonInheritanceNamePass(cpg).createAndApply() - val typeRecoveryConfig = pyConfig match - case Some(config) => - XTypeRecoveryConfig(config.typePropagationIterations, !config.disableDummyTypes) - case None => XTypeRecoveryConfig() - new PythonTypeRecoveryPass(cpg, typeRecoveryConfig).createAndApply() - new PythonTypeHintCallLinker(cpg).createAndApply() - - // Some of passes above create new methods, so, we - // need to run the ASTLinkerPass one more time - new AstLinkerPass(cpg).createAndApply() - - cpg - - override def isJvmBased = true + private lazy val command: Path = + if isWin then rootPath.resolve("pysrc2cpg.bat") else rootPath.resolve("pysrc2cpg") + private var pyConfig: Option[Py2CpgOnFileSystemConfig] = None + + /** Generate a CPG for the given input path. Returns the output path, or None, if no CPG was + * generated. + */ + override def generate(inputPath: String, outputPath: String = "app.atom"): Try[String] = + val arguments = Seq(inputPath, "-o", outputPath) ++ config.cmdLineParams + pyConfig = X2Cpg.parseCommandLine( + arguments.toArray, + NewMain.getCmdLineParser, + Py2CpgOnFileSystemConfig() + ) + runShellCommand(command.toString, arguments).map(_ => outputPath) + + override def isAvailable: Boolean = + command.toFile.exists + + override def applyPostProcessingPasses(cpg: Cpg): Cpg = + new ImportsPass(cpg).createAndApply() + new ImportResolverPass(cpg).createAndApply() + new DynamicTypeHintFullNamePass(cpg).createAndApply() + new PythonInheritanceNamePass(cpg).createAndApply() + val typeRecoveryConfig = pyConfig match + case Some(config) => + XTypeRecoveryConfig(config.typePropagationIterations, !config.disableDummyTypes) + case None => XTypeRecoveryConfig() + new PythonTypeRecoveryPass(cpg, typeRecoveryConfig).createAndApply() + new PythonTypeHintCallLinker(cpg).createAndApply() + + // Some of passes above create new methods, so, we + // need to run the ASTLinkerPass one more time + new AstLinkerPass(cpg).createAndApply() + + cpg + + override def isJvmBased = true end PythonSrcCpgGenerator diff --git a/console/src/main/scala/io/appthreat/console/cpgcreation/package.scala b/console/src/main/scala/io/appthreat/console/cpgcreation/package.scala index 28791221..704852fb 100644 --- a/console/src/main/scala/io/appthreat/console/cpgcreation/package.scala +++ b/console/src/main/scala/io/appthreat/console/cpgcreation/package.scala @@ -8,95 +8,96 @@ import scala.collection.mutable package object cpgcreation: - /** For a given language, return CPG generator script - */ - def cpgGeneratorForLanguage( - language: String, - config: FrontendConfig, - rootPath: Path, - args: List[String] - ): Option[CpgGenerator] = - lazy val conf = config.withArgs(args) - language.toUpperCase() match - case Languages.C | Languages.NEWC | Languages.JAVA | Languages.JAVASRC | Languages.JSSRC | Languages.JAVASCRIPT | - Languages.PYTHON | Languages.PYTHONSRC => - Some(AtomGenerator(conf, rootPath, language)) - case _ => None - - /** Heuristically determines language by inspecting file/dir at path. - */ - def guessLanguage(path: String): Option[String] = - val file = File(path) - if file.isDirectory then - guessMajorityLanguageInDir(file) - else - guessLanguageForRegularFile(file) - - /** Guess the main language for an entire directory (e.g. a whole project), based on a group - * count of all individual files. Rationale: many projects contain files from different - * languages, but most often one language is standing out in numbers. - */ - private def guessMajorityLanguageInDir(directory: File): Option[String] = - assert(directory.isDirectory, s"$directory must be a directory, but wasn't") - val groupCount = mutable.Map.empty[String, Int].withDefaultValue(0) - - for - file <- directory.listRecursively - if file.isRegularFile - guessedLanguage <- guessLanguageForRegularFile(file) - do - val oldValue = groupCount(guessedLanguage) - groupCount.update(guessedLanguage, oldValue + 1) - - groupCount.toSeq.sortBy(_._2).lastOption.map(_._1) - - private def isJavaBinary(filename: String): Boolean = - Seq(".jar", ".war", ".ear", ".apk").exists(filename.endsWith) - - private def isCsharpFile(filename: String): Boolean = - Seq(".csproj", ".cs").exists(filename.endsWith) - - private def isGoFile(filename: String): Boolean = - filename.endsWith(".go") || Set("gopkg.lock", "gopkg.toml", "go.mod", "go.sum").contains( - filename - ) - - private def isLlvmFile(filename: String): Boolean = - Seq(".bc", ".ll").exists(filename.endsWith) - - private def isJsFile(filename: String): Boolean = - Seq(".js", ".ts", ".jsx", ".tsx").exists(filename.endsWith) || filename == "package.json" - - /** check if given filename looks like it might be a C/CPP source or header file mostly copied - * from io.appthreat.c2cpg.parser.FileDefaults - */ - private def isCFile(filename: String): Boolean = - Seq(".c", ".cc", ".cpp", ".h", ".hpp", ".hh", ".ccm", ".cxxm", ".c++m").exists( - filename.endsWith - ) - - private def isYamlFile(filename: String): Boolean = - Seq(".yml", ".yaml").exists(filename.endsWith) - - private def isBomFile(filename: String): Boolean = - Seq("bom.json", ".cdx.json").exists(filename.endsWith) - - private def guessLanguageForRegularFile(file: File): Option[String] = - file.name.toLowerCase match - case f if isJavaBinary(f) => Some(Languages.JAVA) - case f if isCsharpFile(f) => Some(Languages.CSHARP) - case f if isGoFile(f) => Some(Languages.GOLANG) - case f if isJsFile(f) => Some(Languages.JSSRC) - case f if f.endsWith(".java") => Some(Languages.JAVASRC) - case f if f.endsWith(".class") => Some(Languages.JAVA) - case f if f.endsWith(".kt") => Some(Languages.KOTLIN) - case f if f.endsWith(".php") => Some(Languages.PHP) - case f if f.endsWith(".py") => Some(Languages.PYTHONSRC) - case f if f.endsWith(".rb") => Some(Languages.RUBYSRC) - case f if isLlvmFile(f) => Some(Languages.LLVM) - case f if isCFile(f) => Some(Languages.NEWC) - case f if isBomFile(f) => Option("BOM") - case f if f.endsWith(".json") => Option("JSON") - case f if isYamlFile(f) => Option("YAML") - case _ => None + /** For a given language, return CPG generator script + */ + def cpgGeneratorForLanguage( + language: String, + config: FrontendConfig, + rootPath: Path, + args: List[String] + ): Option[CpgGenerator] = + lazy val conf = config.withArgs(args) + language.toUpperCase() match + case Languages.C | Languages.NEWC | Languages.JAVA | Languages.JAVASRC | Languages + .JSSRC | Languages.JAVASCRIPT | + Languages.PYTHON | Languages.PYTHONSRC => + Some(AtomGenerator(conf, rootPath, language)) + case _ => None + + /** Heuristically determines language by inspecting file/dir at path. + */ + def guessLanguage(path: String): Option[String] = + val file = File(path) + if file.isDirectory then + guessMajorityLanguageInDir(file) + else + guessLanguageForRegularFile(file) + + /** Guess the main language for an entire directory (e.g. a whole project), based on a group count + * of all individual files. Rationale: many projects contain files from different languages, but + * most often one language is standing out in numbers. + */ + private def guessMajorityLanguageInDir(directory: File): Option[String] = + assert(directory.isDirectory, s"$directory must be a directory, but wasn't") + val groupCount = mutable.Map.empty[String, Int].withDefaultValue(0) + + for + file <- directory.listRecursively + if file.isRegularFile + guessedLanguage <- guessLanguageForRegularFile(file) + do + val oldValue = groupCount(guessedLanguage) + groupCount.update(guessedLanguage, oldValue + 1) + + groupCount.toSeq.sortBy(_._2).lastOption.map(_._1) + + private def isJavaBinary(filename: String): Boolean = + Seq(".jar", ".war", ".ear", ".apk").exists(filename.endsWith) + + private def isCsharpFile(filename: String): Boolean = + Seq(".csproj", ".cs").exists(filename.endsWith) + + private def isGoFile(filename: String): Boolean = + filename.endsWith(".go") || Set("gopkg.lock", "gopkg.toml", "go.mod", "go.sum").contains( + filename + ) + + private def isLlvmFile(filename: String): Boolean = + Seq(".bc", ".ll").exists(filename.endsWith) + + private def isJsFile(filename: String): Boolean = + Seq(".js", ".ts", ".jsx", ".tsx").exists(filename.endsWith) || filename == "package.json" + + /** check if given filename looks like it might be a C/CPP source or header file mostly copied + * from io.appthreat.c2cpg.parser.FileDefaults + */ + private def isCFile(filename: String): Boolean = + Seq(".c", ".cc", ".cpp", ".h", ".hpp", ".hh", ".ccm", ".cxxm", ".c++m").exists( + filename.endsWith + ) + + private def isYamlFile(filename: String): Boolean = + Seq(".yml", ".yaml").exists(filename.endsWith) + + private def isBomFile(filename: String): Boolean = + Seq("bom.json", ".cdx.json").exists(filename.endsWith) + + private def guessLanguageForRegularFile(file: File): Option[String] = + file.name.toLowerCase match + case f if isJavaBinary(f) => Some(Languages.JAVA) + case f if isCsharpFile(f) => Some(Languages.CSHARP) + case f if isGoFile(f) => Some(Languages.GOLANG) + case f if isJsFile(f) => Some(Languages.JSSRC) + case f if f.endsWith(".java") => Some(Languages.JAVASRC) + case f if f.endsWith(".class") => Some(Languages.JAVA) + case f if f.endsWith(".kt") => Some(Languages.KOTLIN) + case f if f.endsWith(".php") => Some(Languages.PHP) + case f if f.endsWith(".py") => Some(Languages.PYTHONSRC) + case f if f.endsWith(".rb") => Some(Languages.RUBYSRC) + case f if isLlvmFile(f) => Some(Languages.LLVM) + case f if isCFile(f) => Some(Languages.NEWC) + case f if isBomFile(f) => Option("BOM") + case f if f.endsWith(".json") => Option("JSON") + case f if isYamlFile(f) => Option("YAML") + case _ => None end cpgcreation diff --git a/console/src/main/scala/io/appthreat/console/package.scala b/console/src/main/scala/io/appthreat/console/package.scala index e1d4a775..6f4e22e4 100644 --- a/console/src/main/scala/io/appthreat/console/package.scala +++ b/console/src/main/scala/io/appthreat/console/package.scala @@ -6,37 +6,35 @@ import replpp.Colors // TODO remove any time after the end of 2023 - this is completely deprecated package object console: - implicit class UnixUtils[A](content: Iterable[A]): - given Colors = Colors.Default - - /** Iterate over left hand side operand and write to file. Think of it as the Ocular version - * of the Unix `>` shell redirection. - */ - @deprecated("please use `#>` instead", "2.0.45 (August 2023)") - def |>(outfile: String): Unit = - content #> outfile - - /** Iterate over left hand side operand and append to file. Think of it as the Ocular - * version of the Unix `>>` shell redirection. - */ - @deprecated("please use `#>>` instead", "2.0.45 (August 2023)") - def |>>(outfile: String): Unit = - content #>> outfile - - implicit class StringOps(value: String): - given Colors = Colors.Default - - /** Pipe string to file. Think of it as the Ocular version of the Unix `>` shell - * redirection. - */ - @deprecated("please use `#>` instead", "2.0.45 (August 2023)") - def |>(outfile: String): Unit = - value #> outfile - - /** Append string to file. Think of it as the Ocular version of the Unix `>>` shell - * redirection. - */ - @deprecated("please use `#>>` instead", "2.0.45 (August 2023)") - def |>>(outfile: String): Unit = - value #>> outfile + implicit class UnixUtils[A](content: Iterable[A]): + given Colors = Colors.Default + + /** Iterate over left hand side operand and write to file. Think of it as the Ocular version of + * the Unix `>` shell redirection. + */ + @deprecated("please use `#>` instead", "2.0.45 (August 2023)") + def |>(outfile: String): Unit = + content #> outfile + + /** Iterate over left hand side operand and append to file. Think of it as the Ocular version of + * the Unix `>>` shell redirection. + */ + @deprecated("please use `#>>` instead", "2.0.45 (August 2023)") + def |>>(outfile: String): Unit = + content #>> outfile + + implicit class StringOps(value: String): + given Colors = Colors.Default + + /** Pipe string to file. Think of it as the Ocular version of the Unix `>` shell redirection. + */ + @deprecated("please use `#>` instead", "2.0.45 (August 2023)") + def |>(outfile: String): Unit = + value #> outfile + + /** Append string to file. Think of it as the Ocular version of the Unix `>>` shell redirection. + */ + @deprecated("please use `#>>` instead", "2.0.45 (August 2023)") + def |>>(outfile: String): Unit = + value #>> outfile end console diff --git a/console/src/main/scala/io/appthreat/console/workspacehandling/Project.scala b/console/src/main/scala/io/appthreat/console/workspacehandling/Project.scala index 86a21893..870ca76b 100644 --- a/console/src/main/scala/io/appthreat/console/workspacehandling/Project.scala +++ b/console/src/main/scala/io/appthreat/console/workspacehandling/Project.scala @@ -8,8 +8,8 @@ import io.shiftleft.semanticcpg.Overlays import java.nio.file.Path object Project: - val workCpgFileName = "cpg.bin.tmp" - val persistentCpgFileName = "cpg.bin" + val workCpgFileName = "cpg.bin.tmp" + val persistentCpgFileName = "cpg.bin" case class ProjectFile(inputPath: String, name: String) @@ -20,42 +20,42 @@ case class ProjectFile(inputPath: String, name: String) */ case class Project(projectFile: ProjectFile, var path: Path, var cpg: Option[Cpg] = None): - import Project.* + import Project.* - def name: String = projectFile.name + def name: String = projectFile.name - def inputPath: String = projectFile.inputPath + def inputPath: String = projectFile.inputPath - def isOpen: Boolean = cpg.isDefined + def isOpen: Boolean = cpg.isDefined - def appliedOverlays: Seq[String] = - cpg.map(Overlays.appliedOverlays).getOrElse(Nil) + def appliedOverlays: Seq[String] = + cpg.map(Overlays.appliedOverlays).getOrElse(Nil) - def availableOverlays: List[String] = - File(path.resolve("overlays")).list.map(_.name).toList + def availableOverlays: List[String] = + File(path.resolve("overlays")).list.map(_.name).toList - def overlayDirs: Seq[File] = - val overlayDir = File(path.resolve("overlays")) - appliedOverlays.map(o => overlayDir / o) + def overlayDirs: Seq[File] = + val overlayDir = File(path.resolve("overlays")) + appliedOverlays.map(o => overlayDir / o) - override def toString: String = - toTableRow.mkString("\t") + override def toString: String = + toTableRow.mkString("\t") - def toTableRow: List[String] = - val cpgLoaded = cpg.isDefined - val overlays = availableOverlays.mkString(",") - val inputPath = projectFile.inputPath - List(name, overlays, inputPath, cpgLoaded.toString) + def toTableRow: List[String] = + val cpgLoaded = cpg.isDefined + val overlays = availableOverlays.mkString(",") + val inputPath = projectFile.inputPath + List(name, overlays, inputPath, cpgLoaded.toString) - /** Close project if it is open and do nothing otherwise. - */ - def close: Project = - cpg.foreach { c => - c.close() - val workingCopy = path.resolve(workCpgFileName) - val persistent = path.resolve(persistentCpgFileName) - cp(workingCopy, persistent) - } - cpg = None - this + /** Close project if it is open and do nothing otherwise. + */ + def close: Project = + cpg.foreach { c => + c.close() + val workingCopy = path.resolve(workCpgFileName) + val persistent = path.resolve(persistentCpgFileName) + cp(workingCopy, persistent) + } + cpg = None + this end Project diff --git a/console/src/main/scala/io/appthreat/console/workspacehandling/Workspace.scala b/console/src/main/scala/io/appthreat/console/workspacehandling/Workspace.scala index 1d96fd80..3519c034 100644 --- a/console/src/main/scala/io/appthreat/console/workspacehandling/Workspace.scala +++ b/console/src/main/scala/io/appthreat/console/workspacehandling/Workspace.scala @@ -11,20 +11,20 @@ import scala.collection.mutable.ListBuffer */ class Workspace[ProjectType <: Project](var projects: ListBuffer[ProjectType]): - /** Returns total number of projects in this workspace - */ - def numberOfProjects: Int = projects.size + /** Returns total number of projects in this workspace + */ + def numberOfProjects: Int = projects.size - /** Provide a human-readable overview of the workspace - */ - override def toString: String = - if projects.isEmpty then - System.err.println( - "The workpace is empty. Use `importCode` or `importAtom` to populate it" - ) - "empty" - else - """ + /** Provide a human-readable overview of the workspace + */ + override def toString: String = + if projects.isEmpty then + System.err.println( + "The workpace is empty. Use `importCode` or `importAtom` to populate it" + ) + "empty" + else + """ |Overview of all projects present in your workspace. You can use `open` and `close` |to load and unload projects respectively. `cpgs` allows you to query all projects |at once. `cpg` points to the Code Property Graph of the *selected* project, which is @@ -33,8 +33,8 @@ class Workspace[ProjectType <: Project](var projects: ListBuffer[ProjectType]): | | Type `run` to add additional overlays to code property graphs |""".stripMargin - "\n" + Table( - columnNames = List("name", "overlays", "inputPath", "open"), - rows = projects.map(_.toTableRow).toList - ).render + "\n" + Table( + columnNames = List("name", "overlays", "inputPath", "open"), + rows = projects.map(_.toTableRow).toList + ).render end Workspace diff --git a/console/src/main/scala/io/appthreat/console/workspacehandling/WorkspaceLoader.scala b/console/src/main/scala/io/appthreat/console/workspacehandling/WorkspaceLoader.scala index 36ea4f8a..4e22dbf7 100644 --- a/console/src/main/scala/io/appthreat/console/workspacehandling/WorkspaceLoader.scala +++ b/console/src/main/scala/io/appthreat/console/workspacehandling/WorkspaceLoader.scala @@ -11,40 +11,40 @@ import scala.util.{Failure, Success, Try} */ abstract class WorkspaceLoader[ProjectType <: Project]: - /** Initialize workspace from a directory - * @param path - * path to the directory - */ - def load(path: String): Workspace[ProjectType] = - val dirFile = File(path) - val dirPath = dirFile.path.toAbsolutePath - - mkdirs(dirFile) - new Workspace(ListBuffer.from(loadProjectsFromFs(dirPath))) - - private def loadProjectsFromFs(cpgsPath: Path): LazyList[ProjectType] = - cpgsPath.toFile.listFiles - .filter(_.isDirectory) - .to(LazyList) - .flatMap(f => loadProject(f.toPath)) - - def loadProject(path: Path): Option[ProjectType] = - Try { - val projectFile = readProjectFile(path) - createProject(projectFile, path) - } match - case Success(v) => Some(v) - case Failure(e) => - System.err.println(s"Error loading project at $path - skipping: ") - e.printStackTrace - None - - def createProject(projectFile: ProjectFile, path: Path): ProjectType - - private val PROJECTFILE_NAME = "project.json" - - private def readProjectFile(projectDirName: Path): ProjectFile = - // TODO see `writeProjectFile` - val data = ujson.read(projectDirName.resolve(PROJECTFILE_NAME)) - ProjectFile(data("inputPath").str, data("name").str) + /** Initialize workspace from a directory + * @param path + * path to the directory + */ + def load(path: String): Workspace[ProjectType] = + val dirFile = File(path) + val dirPath = dirFile.path.toAbsolutePath + + mkdirs(dirFile) + new Workspace(ListBuffer.from(loadProjectsFromFs(dirPath))) + + private def loadProjectsFromFs(cpgsPath: Path): LazyList[ProjectType] = + cpgsPath.toFile.listFiles + .filter(_.isDirectory) + .to(LazyList) + .flatMap(f => loadProject(f.toPath)) + + def loadProject(path: Path): Option[ProjectType] = + Try { + val projectFile = readProjectFile(path) + createProject(projectFile, path) + } match + case Success(v) => Some(v) + case Failure(e) => + System.err.println(s"Error loading project at $path - skipping: ") + e.printStackTrace + None + + def createProject(projectFile: ProjectFile, path: Path): ProjectType + + private val PROJECTFILE_NAME = "project.json" + + private def readProjectFile(projectDirName: Path): ProjectFile = + // TODO see `writeProjectFile` + val data = ujson.read(projectDirName.resolve(PROJECTFILE_NAME)) + ProjectFile(data("inputPath").str, data("name").str) end WorkspaceLoader diff --git a/console/src/main/scala/io/appthreat/console/workspacehandling/WorkspaceManager.scala b/console/src/main/scala/io/appthreat/console/workspacehandling/WorkspaceManager.scala index bd55731a..46e6746d 100644 --- a/console/src/main/scala/io/appthreat/console/workspacehandling/WorkspaceManager.scala +++ b/console/src/main/scala/io/appthreat/console/workspacehandling/WorkspaceManager.scala @@ -16,8 +16,8 @@ import scala.collection.mutable.ListBuffer import scala.util.{Failure, Success, Try} object DefaultLoader extends WorkspaceLoader[Project]: - override def createProject(projectFile: ProjectFile, path: Path): Project = - Project(projectFile, path) + override def createProject(projectFile: ProjectFile, path: Path): Project = + Project(projectFile, path) /** WorkspaceManager: a component, which loads and maintains the list of projects made accessible * via Ocular/Joern. @@ -30,356 +30,356 @@ class WorkspaceManager[ProjectType <: Project]( loader: WorkspaceLoader[ProjectType] = DefaultLoader ) extends Reporting: - def getPath: String = path - - import WorkspaceManager.* - - /** The workspace managed by this WorkspaceManager - */ - private var workspace: Workspace[ProjectType] = loader.load(path) - private val dirPath = File(path).path.toAbsolutePath - - private val LEGACY_BASE_CPG_FILENAME = "app.atom" - private val OVERLAY_DIR_NAME = "overlays" - - /** Create project for code stored at `inputPath` with the project name `name`. If `name` is - * empty, the project name is derived from `inputPath`. If a project for this `name` already - * exists, it is deleted from the workspace first. If no file or directory exists at - * `inputPath`, then no project is created. Returns the path to the project directory as an - * optional String, and None if there was an error. - */ - def createProject(inputPath: String, projectName: String): Option[Path] = - Some(File(inputPath)).filter(_.exists).map { _ => - val pathToProject = projectNameToDir(projectName) - - if project(projectName).isDefined then - removeProject(projectName) - - createProjectDirectory(inputPath, name = projectName) - loader.loadProject(pathToProject).foreach(addProjectToProjectsList) - pathToProject - } - - /** Remove project named `name` from disk - * @param name - * name of the project - */ - def removeProject(name: String): Unit = - closeProject(name) - removeProjectFromList(name) - File(projectNameToDir(name)).delete() - - /** Create new project directory containing project file. - */ - private def createProjectDirectory(inputPath: String, name: String): Unit = - val dirPath = projectNameToDir(name) - mkdirs(dirPath) - val absoluteInputPath = File(inputPath).path.toAbsolutePath.toString - mkdirs(File(overlayDir(name))) - val projectFile = ProjectFile(absoluteInputPath, name) - writeProjectFile(projectFile, dirPath) - touch(dirPath.toString / "cpg.bin") - - /** Write the project's `project.json`, a JSON file that holds meta information. - */ - private def writeProjectFile(projectFile: ProjectFile, dirPath: Path): File = - // TODO proguard and json4s don't play along. We actually want to - // serialize the case class ProjectFile here, but it comes out - // empty. This code will be moved to `codepropertgraph` at - // which point serialization should work. - // val content = jsonWrite(projectFile) - implicit val formats: DefaultFormats.type = DefaultFormats - val PROJECTFILE_NAME = "project.json" - val content = - jsonWrite(Map("inputPath" -> projectFile.inputPath, "name" -> projectFile.name)) - File(dirPath.resolve(PROJECTFILE_NAME)).write(content) - - /** Delete the workspace from disk, then initialize it again. - */ - def reset: Unit = - Try(cpg.close()) - deleteWorkspace() - workspace = loader.load(path) - - private def deleteWorkspace(): Unit = - if dirPath == null || dirPath.toString == "" then - throw new RuntimeException("dirPath is not set") - val dirFile = better.files.File(dirPath.toAbsolutePath.toString) - if !dirFile.exists then - throw new RuntimeException(s"Directory ${dirFile.toString} does not exist") - - dirFile.delete() - - /** Return the number of projects currently present in this workspace. - */ - def numberOfProjects: Int = workspace.projects.size - - def projects: List[Project] = workspace.projects.toList - - def project(name: String): Option[Project] = workspace.projects.find(_.name == name) - - override def toString: String = workspace.toString - - private def getProjectByPath(projectPath: Path): Option[Project] = - workspace.projects.find(_.path.toAbsolutePath == projectPath.toAbsolutePath) - - /** A sorted list of all loaded CPGs - */ - def loadedCpgs: List[Cpg] = workspace.projects.flatMap(_.cpg).toList - - /** Indicates whether a workspace record exists for @inputPath. - */ - def projectExists(inputPath: String): Boolean = projectDir(inputPath).toFile.exists - - /** Indicates whether a base CPG exists for @inputPath. - */ - def cpgExists(inputPath: String, isLegacy: Boolean = false): Boolean = - val baseFileName = if isLegacy then LEGACY_BASE_CPG_FILENAME else BASE_CPG_FILENAME - projectExists(inputPath) && - projectDir(inputPath).resolve(baseFileName).toFile.exists - - /** Overlay directory for CPG with given @inputPath - */ - def overlayDir(inputPath: String): String = - projectDir(inputPath).resolve(OVERLAY_DIR_NAME).toString - - def overlayDirByProjectName(name: String): String = - projectNameToDir(name).resolve(OVERLAY_DIR_NAME).toString - - /** Filename for the base CPG for @inputPath - */ - private def baseCpgFilename(inputPath: String, isLegacy: Boolean = false): String = - val baseFileName = if isLegacy then LEGACY_BASE_CPG_FILENAME else BASE_CPG_FILENAME - projectDir(inputPath).resolve(baseFileName).toString - - /** The safe directory name for a given input file that can be used to store its base CPG, along - * with all overlays. This method returns a directory name regardless of whether the directory - * exists or not. - */ - private def projectDir(inputPath: String): Path = - val filename = File(inputPath).path.getFileName - if filename == null then - throw new RuntimeException("invalid input path: " + inputPath) - dirPath - .resolve(URLEncoder.encode(filename.toString, DefaultCharset.toString)) - .toAbsolutePath - - private def projectNameToDir(name: String): Path = - dirPath - .resolve(URLEncoder.encode(name, DefaultCharset.toString)) - .toAbsolutePath - - /** Record for the given name. None if it is not in the workspace. - */ - private def projectByName(name: String): Option[ProjectType] = - workspace.projects.find(r => r.name == name) - - /** Workspace record for the CPG, or none, if the CPG is not in the workspace - */ - def projectByCpg(baseCpg: Cpg): Option[ProjectType] = - workspace.projects.find(_.cpg.contains(baseCpg)) - - def projectExistsForCpg(baseCpg: Cpg): Boolean = projectByCpg(baseCpg).isDefined - - def getNextOverlayDirName(baseCpg: Cpg, overlayName: String): String = - val project = projectByCpg(baseCpg).get - val overlayDirectory = File(overlayDirByProjectName(project.name)) - - val overlayFile = overlayDirectory.path - .resolve(overlayName) - .toFile - - overlayFile.getAbsolutePath - - /** Obtain the cpg that was last loaded. Throws a runtime exception if no CPG has been loaded. - */ - def cpg: Cpg = - val project = workspace.projects.lastOption - project match - case Some(p) => - p.cpg match - case Some(value) => value - case None => - throw new RuntimeException( - s"No Atom loaded for exploration - try importing one using `importCode|importAtom`" - ) - case None => throw new RuntimeException("No projects loaded") - - /** Set active project to project with name `name`. If a project with this name does not exist, - * does nothing. - */ - def setActiveProject(name: String): Option[ProjectType] = - val project = projectByName(name) - if project.isEmpty then - System.err.println(s"Error: project with name $name does not exist") - None - else - removeProjectFromList(name).map { p => - addProjectToProjectsList(p) - p - } - - /** Retrieve the currently active project. If no project is active, None is returned. - */ - def getActiveProject: Option[Project] = - workspace.projects.lastOption - - /** Open project by name and return it. If a project with this name does not exist, None is - * returned. If the CPG of this project is loaded, it is unloaded first and then reloaded. - * Returns project or None on error. - * - * @param name - * of the project to load - * @param loader - * function to perform CPG loading. This parameter only exists for testing purposes. - */ - def openProject( - name: String, - loader: String => Option[Cpg] = { x => - loadCpgRaw(x) + def getPath: String = path + + import WorkspaceManager.* + + /** The workspace managed by this WorkspaceManager + */ + private var workspace: Workspace[ProjectType] = loader.load(path) + private val dirPath = File(path).path.toAbsolutePath + + private val LEGACY_BASE_CPG_FILENAME = "app.atom" + private val OVERLAY_DIR_NAME = "overlays" + + /** Create project for code stored at `inputPath` with the project name `name`. If `name` is + * empty, the project name is derived from `inputPath`. If a project for this `name` already + * exists, it is deleted from the workspace first. If no file or directory exists at `inputPath`, + * then no project is created. Returns the path to the project directory as an optional String, + * and None if there was an error. + */ + def createProject(inputPath: String, projectName: String): Option[Path] = + Some(File(inputPath)).filter(_.exists).map { _ => + val pathToProject = projectNameToDir(projectName) + + if project(projectName).isDefined then + removeProject(projectName) + + createProjectDirectory(inputPath, name = projectName) + loader.loadProject(pathToProject).foreach(addProjectToProjectsList) + pathToProject } - ): Option[Project] = - if !projectExists(name) then - report( - s"Project does not exist in workspace. Try `importCode/importAtom(inputPath)` to create it" - ) - None - else if !File(baseCpgFilename(name)).exists then - report(s"CPG for project $name does not exist at ${baseCpgFilename(name)}, bailing out") - None - else if project(name).exists(_.cpg.isDefined) then - setActiveProject(name) - project(name) - else - val cpgFilename = baseCpgFilename(name) - val cpgFile = File(cpgFilename) - val workingCopyPath = projectDir(name).resolve(Project.workCpgFileName) - val workingCopyName = workingCopyPath.toAbsolutePath.toString - cp(cpgFile, workingCopyPath) - - val result = - val newCpg = loader(workingCopyName) - val projectPath = File(workingCopyName).parent.path - newCpg.flatMap { c => - unloadCpgIfExists(name) - setCpgForProject(c, projectPath) - projectByCpg(c) - } - result - - /** Free up resources occupied by this project but do not remove project from disk. - */ - def closeProject(name: String): Option[Project] = - projectByName(name).map(_.close) - - /** Set CPG for existing project. It is assumed that the CPG is loaded. - */ - private def setCpgForProject(newCpg: Cpg, projectPath: Path): Unit = - val project = getProjectByPath(projectPath) - project match - case Some(p) => - p.cpg = Some(newCpg) - setActiveProject(p.name) + + /** Remove project named `name` from disk + * @param name + * name of the project + */ + def removeProject(name: String): Unit = + closeProject(name) + removeProjectFromList(name) + File(projectNameToDir(name)).delete() + + /** Create new project directory containing project file. + */ + private def createProjectDirectory(inputPath: String, name: String): Unit = + val dirPath = projectNameToDir(name) + mkdirs(dirPath) + val absoluteInputPath = File(inputPath).path.toAbsolutePath.toString + mkdirs(File(overlayDir(name))) + val projectFile = ProjectFile(absoluteInputPath, name) + writeProjectFile(projectFile, dirPath) + touch(dirPath.toString / "cpg.bin") + + /** Write the project's `project.json`, a JSON file that holds meta information. + */ + private def writeProjectFile(projectFile: ProjectFile, dirPath: Path): File = + // TODO proguard and json4s don't play along. We actually want to + // serialize the case class ProjectFile here, but it comes out + // empty. This code will be moved to `codepropertgraph` at + // which point serialization should work. + // val content = jsonWrite(projectFile) + implicit val formats: DefaultFormats.type = DefaultFormats + val PROJECTFILE_NAME = "project.json" + val content = + jsonWrite(Map("inputPath" -> projectFile.inputPath, "name" -> projectFile.name)) + File(dirPath.resolve(PROJECTFILE_NAME)).write(content) + + /** Delete the workspace from disk, then initialize it again. + */ + def reset: Unit = + Try(cpg.close()) + deleteWorkspace() + workspace = loader.load(path) + + private def deleteWorkspace(): Unit = + if dirPath == null || dirPath.toString == "" then + throw new RuntimeException("dirPath is not set") + val dirFile = better.files.File(dirPath.toAbsolutePath.toString) + if !dirFile.exists then + throw new RuntimeException(s"Directory ${dirFile.toString} does not exist") + + dirFile.delete() + + /** Return the number of projects currently present in this workspace. + */ + def numberOfProjects: Int = workspace.projects.size + + def projects: List[Project] = workspace.projects.toList + + def project(name: String): Option[Project] = workspace.projects.find(_.name == name) + + override def toString: String = workspace.toString + + private def getProjectByPath(projectPath: Path): Option[Project] = + workspace.projects.find(_.path.toAbsolutePath == projectPath.toAbsolutePath) + + /** A sorted list of all loaded CPGs + */ + def loadedCpgs: List[Cpg] = workspace.projects.flatMap(_.cpg).toList + + /** Indicates whether a workspace record exists for @inputPath. + */ + def projectExists(inputPath: String): Boolean = projectDir(inputPath).toFile.exists + + /** Indicates whether a base CPG exists for @inputPath. + */ + def cpgExists(inputPath: String, isLegacy: Boolean = false): Boolean = + val baseFileName = if isLegacy then LEGACY_BASE_CPG_FILENAME else BASE_CPG_FILENAME + projectExists(inputPath) && + projectDir(inputPath).resolve(baseFileName).toFile.exists + + /** Overlay directory for CPG with given @inputPath + */ + def overlayDir(inputPath: String): String = + projectDir(inputPath).resolve(OVERLAY_DIR_NAME).toString + + def overlayDirByProjectName(name: String): String = + projectNameToDir(name).resolve(OVERLAY_DIR_NAME).toString + + /** Filename for the base CPG for @inputPath + */ + private def baseCpgFilename(inputPath: String, isLegacy: Boolean = false): String = + val baseFileName = if isLegacy then LEGACY_BASE_CPG_FILENAME else BASE_CPG_FILENAME + projectDir(inputPath).resolve(baseFileName).toString + + /** The safe directory name for a given input file that can be used to store its base CPG, along + * with all overlays. This method returns a directory name regardless of whether the directory + * exists or not. + */ + private def projectDir(inputPath: String): Path = + val filename = File(inputPath).path.getFileName + if filename == null then + throw new RuntimeException("invalid input path: " + inputPath) + dirPath + .resolve(URLEncoder.encode(filename.toString, DefaultCharset.toString)) + .toAbsolutePath + + private def projectNameToDir(name: String): Path = + dirPath + .resolve(URLEncoder.encode(name, DefaultCharset.toString)) + .toAbsolutePath + + /** Record for the given name. None if it is not in the workspace. + */ + private def projectByName(name: String): Option[ProjectType] = + workspace.projects.find(r => r.name == name) + + /** Workspace record for the CPG, or none, if the CPG is not in the workspace + */ + def projectByCpg(baseCpg: Cpg): Option[ProjectType] = + workspace.projects.find(_.cpg.contains(baseCpg)) + + def projectExistsForCpg(baseCpg: Cpg): Boolean = projectByCpg(baseCpg).isDefined + + def getNextOverlayDirName(baseCpg: Cpg, overlayName: String): String = + val project = projectByCpg(baseCpg).get + val overlayDirectory = File(overlayDirByProjectName(project.name)) + + val overlayFile = overlayDirectory.path + .resolve(overlayName) + .toFile + + overlayFile.getAbsolutePath + + /** Obtain the cpg that was last loaded. Throws a runtime exception if no CPG has been loaded. + */ + def cpg: Cpg = + val project = workspace.projects.lastOption + project match + case Some(p) => + p.cpg match + case Some(value) => value case None => - System.err.println( - s"Error setting CPG for non-existing/unloaded project at $projectPath" + throw new RuntimeException( + s"No Atom loaded for exploration - try importing one using `importCode|importAtom`" ) + case None => throw new RuntimeException("No projects loaded") + + /** Set active project to project with name `name`. If a project with this name does not exist, + * does nothing. + */ + def setActiveProject(name: String): Option[ProjectType] = + val project = projectByName(name) + if project.isEmpty then + System.err.println(s"Error: project with name $name does not exist") + None + else + removeProjectFromList(name).map { p => + addProjectToProjectsList(p) + p + } - private def loadCpgRaw(cpgFilename: String): Option[Cpg] = - Try { - val odbConfig = Config.withDefaults.withStorageLocation(cpgFilename) - val config = - CpgLoaderConfig.withDefaults.doNotCreateIndexesOnLoad.withOverflowConfig(odbConfig) - val newCpg = CpgLoader.loadFromOverflowDb(config) - CpgLoader.createIndexes(newCpg) - newCpg - } match - case Success(v) => Some(v) - case Failure(ex) => - System.err.println("Error loading CPG") - System.err.println(ex) - None - - private def addProjectToProjectsList(project: ProjectType): ListBuffer[ProjectType] = - workspace.projects += project - - def unloadCpgByProjectName(name: String): Unit = - projectByName(name).foreach { record => - record.cpg.foreach(_.close) - record.cpg = None - } - - def reloadCpgByName(name: String, loadCpg: String => Option[Cpg]): Option[Cpg] = - projectByName(name).flatMap { record => - record.cpg = loadCpg(record.name) - record.cpg - } - - private def unloadCpgIfExists(name: String): Unit = - projectByName(File(projectDir(name)).name) - .flatMap(_.cpg) - .foreach { c => - try - c.close - catch - case _: IllegalStateException => // Store is already closed - } - - /** Remove currently active project from workspace and delete all associated workspace files - * from disk. - */ - def deleteCurrentProject(): Unit = - val project = projectByCpg(cpg) - project match - case Some(p) => deleteProject(p) - case None => - report(s"Project for active CPG does not exist") - - /** Remove project with name `name` from workspace and delete all associated workspace files - * from disk. - * @param name - * the name of the project that should be removed - */ - def deleteProject(name: String): Option[Unit] = - val project = projectByName(name) - project match - case Some(p) => - deleteProject(p) - Option[Unit](()) - case None => - report(s"Project with name $name does not exist") - None - - private def deleteProject(project: Project): Unit = - removeProjectFromList(project.name) - if project.path.toString != "" then - File(project.path).delete() - - private def removeProjectFromList(name: String): Option[ProjectType] = - workspace.projects.zipWithIndex - .find { case (record, _) => - record.name == name - } - .map(_._2) - .map { index => - workspace.projects.remove(index) - } - - // Kept for backward compatibility - @deprecated("", "") - def recordExists(inputPath: String): Boolean = projectExists(inputPath) - - @deprecated("", "") - def baseCpgExists(inputPath: String, isLegacy: Boolean = false): Boolean = - cpgExists(inputPath, isLegacy) + /** Retrieve the currently active project. If no project is active, None is returned. + */ + def getActiveProject: Option[Project] = + workspace.projects.lastOption + + /** Open project by name and return it. If a project with this name does not exist, None is + * returned. If the CPG of this project is loaded, it is unloaded first and then reloaded. + * Returns project or None on error. + * + * @param name + * of the project to load + * @param loader + * function to perform CPG loading. This parameter only exists for testing purposes. + */ + def openProject( + name: String, + loader: String => Option[Cpg] = { x => + loadCpgRaw(x) + } + ): Option[Project] = + if !projectExists(name) then + report( + s"Project does not exist in workspace. Try `importCode/importAtom(inputPath)` to create it" + ) + None + else if !File(baseCpgFilename(name)).exists then + report(s"CPG for project $name does not exist at ${baseCpgFilename(name)}, bailing out") + None + else if project(name).exists(_.cpg.isDefined) then + setActiveProject(name) + project(name) + else + val cpgFilename = baseCpgFilename(name) + val cpgFile = File(cpgFilename) + val workingCopyPath = projectDir(name).resolve(Project.workCpgFileName) + val workingCopyName = workingCopyPath.toAbsolutePath.toString + cp(cpgFile, workingCopyPath) + + val result = + val newCpg = loader(workingCopyName) + val projectPath = File(workingCopyName).parent.path + newCpg.flatMap { c => + unloadCpgIfExists(name) + setCpgForProject(c, projectPath) + projectByCpg(c) + } + result + + /** Free up resources occupied by this project but do not remove project from disk. + */ + def closeProject(name: String): Option[Project] = + projectByName(name).map(_.close) + + /** Set CPG for existing project. It is assumed that the CPG is loaded. + */ + private def setCpgForProject(newCpg: Cpg, projectPath: Path): Unit = + val project = getProjectByPath(projectPath) + project match + case Some(p) => + p.cpg = Some(newCpg) + setActiveProject(p.name) + case None => + System.err.println( + s"Error setting CPG for non-existing/unloaded project at $projectPath" + ) + + private def loadCpgRaw(cpgFilename: String): Option[Cpg] = + Try { + val odbConfig = Config.withDefaults.withStorageLocation(cpgFilename) + val config = + CpgLoaderConfig.withDefaults.doNotCreateIndexesOnLoad.withOverflowConfig(odbConfig) + val newCpg = CpgLoader.loadFromOverflowDb(config) + CpgLoader.createIndexes(newCpg) + newCpg + } match + case Success(v) => Some(v) + case Failure(ex) => + System.err.println("Error loading CPG") + System.err.println(ex) + None + + private def addProjectToProjectsList(project: ProjectType): ListBuffer[ProjectType] = + workspace.projects += project + + def unloadCpgByProjectName(name: String): Unit = + projectByName(name).foreach { record => + record.cpg.foreach(_.close) + record.cpg = None + } + + def reloadCpgByName(name: String, loadCpg: String => Option[Cpg]): Option[Cpg] = + projectByName(name).flatMap { record => + record.cpg = loadCpg(record.name) + record.cpg + } + + private def unloadCpgIfExists(name: String): Unit = + projectByName(File(projectDir(name)).name) + .flatMap(_.cpg) + .foreach { c => + try + c.close + catch + case _: IllegalStateException => // Store is already closed + } + + /** Remove currently active project from workspace and delete all associated workspace files from + * disk. + */ + def deleteCurrentProject(): Unit = + val project = projectByCpg(cpg) + project match + case Some(p) => deleteProject(p) + case None => + report(s"Project for active CPG does not exist") + + /** Remove project with name `name` from workspace and delete all associated workspace files from + * disk. + * @param name + * the name of the project that should be removed + */ + def deleteProject(name: String): Option[Unit] = + val project = projectByName(name) + project match + case Some(p) => + deleteProject(p) + Option[Unit](()) + case None => + report(s"Project with name $name does not exist") + None + + private def deleteProject(project: Project): Unit = + removeProjectFromList(project.name) + if project.path.toString != "" then + File(project.path).delete() + + private def removeProjectFromList(name: String): Option[ProjectType] = + workspace.projects.zipWithIndex + .find { case (record, _) => + record.name == name + } + .map(_._2) + .map { index => + workspace.projects.remove(index) + } + + // Kept for backward compatibility + @deprecated("", "") + def recordExists(inputPath: String): Boolean = projectExists(inputPath) + + @deprecated("", "") + def baseCpgExists(inputPath: String, isLegacy: Boolean = false): Boolean = + cpgExists(inputPath, isLegacy) end WorkspaceManager object WorkspaceManager: - private val BASE_CPG_FILENAME = "cpg.bin" + private val BASE_CPG_FILENAME = "cpg.bin" - def overlayFilesForDir(dirName: String): List[File] = - File(dirName).list - .filter(f => f.isRegularFile && f.name != BASE_CPG_FILENAME) - .toList - .sortBy(_.name) + def overlayFilesForDir(dirName: String): List[File] = + File(dirName).list + .filter(f => f.isRegularFile && f.name != BASE_CPG_FILENAME) + .toList + .sortBy(_.name) diff --git a/console/src/test/scala/io/appthreat/console/ConsoleTests.scala b/console/src/test/scala/io/appthreat/console/ConsoleTests.scala index 25f98796..165399ca 100644 --- a/console/src/test/scala/io/appthreat/console/ConsoleTests.scala +++ b/console/src/test/scala/io/appthreat/console/ConsoleTests.scala @@ -1,7 +1,7 @@ package io.appthreat.console -import better.files.Dsl._ -import better.files._ +import better.files.Dsl.* +import better.files.* import io.appthreat.console.testing._ import io.appthreat.x2cpg.X2Cpg.defaultOverlayCreators import io.appthreat.x2cpg.layers.{Base, CallGraph, ControlFlow, TypeRelations} diff --git a/console/src/test/scala/io/appthreat/console/LanguageHelperTests.scala b/console/src/test/scala/io/appthreat/console/LanguageHelperTests.scala index d06d22b1..ce2994be 100644 --- a/console/src/test/scala/io/appthreat/console/LanguageHelperTests.scala +++ b/console/src/test/scala/io/appthreat/console/LanguageHelperTests.scala @@ -1,7 +1,7 @@ package io.appthreat.console -import better.files.Dsl._ -import better.files._ +import better.files.Dsl.* +import better.files.* import io.shiftleft.codepropertygraph.generated.Languages import io.appthreat.console.cpgcreation.guessLanguage import org.scalatest.matchers.should.Matchers diff --git a/console/src/test/scala/io/appthreat/console/PluginManagerTests.scala b/console/src/test/scala/io/appthreat/console/PluginManagerTests.scala index e0fdd8b0..41c5bc2f 100644 --- a/console/src/test/scala/io/appthreat/console/PluginManagerTests.scala +++ b/console/src/test/scala/io/appthreat/console/PluginManagerTests.scala @@ -1,7 +1,7 @@ package io.appthreat.console -import better.files.Dsl._ -import better.files._ +import better.files.Dsl.* +import better.files.* import io.shiftleft.utils.ProjectRoot import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec diff --git a/console/src/test/scala/io/appthreat/console/testing/package.scala b/console/src/test/scala/io/appthreat/console/testing/package.scala index 8bf64139..f0287dec 100644 --- a/console/src/test/scala/io/appthreat/console/testing/package.scala +++ b/console/src/test/scala/io/appthreat/console/testing/package.scala @@ -1,7 +1,7 @@ package io.appthreat.console -import better.files.Dsl._ -import better.files._ +import better.files.Dsl.* +import better.files.* import io.appthreat.console.workspacehandling.Project import scala.util.Try diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/DefaultSemantics.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/DefaultSemantics.scala index 8265e558..1621f5f4 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/DefaultSemantics.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/DefaultSemantics.scala @@ -7,176 +7,176 @@ import scala.annotation.unused object DefaultSemantics: - /** @return - * a default set of common external procedure calls for all languages. - */ - def apply(): Semantics = - val list = operatorFlows ++ cFlows ++ javaFlows - Semantics.fromList(list) + /** @return + * a default set of common external procedure calls for all languages. + */ + def apply(): Semantics = + val list = operatorFlows ++ cFlows ++ javaFlows + Semantics.fromList(list) - private def F = (x: String, y: List[(Int, Int)]) => FlowSemantic.from(x, y) + private def F = (x: String, y: List[(Int, Int)]) => FlowSemantic.from(x, y) - private def PTF(x: String, ys: List[(Int, Int)] = List.empty): FlowSemantic = - FlowSemantic(x).copy(mappings = FlowSemantic.from(x, ys).mappings :+ PassThroughMapping) + private def PTF(x: String, ys: List[(Int, Int)] = List.empty): FlowSemantic = + FlowSemantic(x).copy(mappings = FlowSemantic.from(x, ys).mappings :+ PassThroughMapping) - def operatorFlows: List[FlowSemantic] = List( - F(Operators.addition, List((1, -1), (2, -1))), - F(Operators.addressOf, List((1, -1))), - F(Operators.assignment, List((2, 1), (2, -1))), - F(Operators.assignmentAnd, List((2, 1), (1, 1), (2, -1))), - F(Operators.assignmentArithmeticShiftRight, List((2, 1), (1, 1), (2, -1))), - F(Operators.assignmentDivision, List((2, 1), (1, 1), (2, -1))), - F(Operators.assignmentExponentiation, List((2, 1), (1, 1), (2, -1))), - F(Operators.assignmentLogicalShiftRight, List((2, 1), (1, 1), (2, -1))), - F(Operators.assignmentMinus, List((2, 1), (1, 1), (2, -1))), - F(Operators.assignmentModulo, List((2, 1), (1, 1), (2, -1))), - F(Operators.assignmentMultiplication, List((2, 1), (1, 1), (2, -1))), - F(Operators.assignmentOr, List((2, 1), (1, 1), (2, -1))), - F(Operators.assignmentPlus, List((2, 1), (1, 1), (2, -1))), - F(Operators.assignmentShiftLeft, List((2, 1), (1, 1), (2, -1))), - F(Operators.assignmentXor, List((2, 1), (1, 1), (2, -1))), - F(Operators.cast, List((1, -1), (2, -1))), - F(Operators.computedMemberAccess, List((1, -1))), - F(Operators.conditional, List((2, -1), (3, -1))), - F(Operators.elvis, List((1, -1), (2, -1))), - F(Operators.notNullAssert, List((1, -1))), - F(Operators.fieldAccess, List((1, -1))), - F(Operators.getElementPtr, List((1, -1))), + def operatorFlows: List[FlowSemantic] = List( + F(Operators.addition, List((1, -1), (2, -1))), + F(Operators.addressOf, List((1, -1))), + F(Operators.assignment, List((2, 1), (2, -1))), + F(Operators.assignmentAnd, List((2, 1), (1, 1), (2, -1))), + F(Operators.assignmentArithmeticShiftRight, List((2, 1), (1, 1), (2, -1))), + F(Operators.assignmentDivision, List((2, 1), (1, 1), (2, -1))), + F(Operators.assignmentExponentiation, List((2, 1), (1, 1), (2, -1))), + F(Operators.assignmentLogicalShiftRight, List((2, 1), (1, 1), (2, -1))), + F(Operators.assignmentMinus, List((2, 1), (1, 1), (2, -1))), + F(Operators.assignmentModulo, List((2, 1), (1, 1), (2, -1))), + F(Operators.assignmentMultiplication, List((2, 1), (1, 1), (2, -1))), + F(Operators.assignmentOr, List((2, 1), (1, 1), (2, -1))), + F(Operators.assignmentPlus, List((2, 1), (1, 1), (2, -1))), + F(Operators.assignmentShiftLeft, List((2, 1), (1, 1), (2, -1))), + F(Operators.assignmentXor, List((2, 1), (1, 1), (2, -1))), + F(Operators.cast, List((1, -1), (2, -1))), + F(Operators.computedMemberAccess, List((1, -1))), + F(Operators.conditional, List((2, -1), (3, -1))), + F(Operators.elvis, List((1, -1), (2, -1))), + F(Operators.notNullAssert, List((1, -1))), + F(Operators.fieldAccess, List((1, -1))), + F(Operators.getElementPtr, List((1, -1))), - // TODO does this still exist? - F(".incBy", List((1, 1), (2, 1), (3, 1), (4, 1))), - F(Operators.indexAccess, List((1, -1))), - F(Operators.indirectComputedMemberAccess, List((1, -1))), - F(Operators.indirectFieldAccess, List((1, -1))), - F(Operators.indirectIndexAccess, List((1, -1), (2, 1))), - F(Operators.indirectMemberAccess, List((1, -1))), - F(Operators.indirection, List((1, -1))), - F(Operators.memberAccess, List((1, -1))), - F(Operators.pointerShift, List((1, -1))), - F(Operators.postDecrement, List((1, 1), (1, -1))), - F(Operators.postIncrement, List((1, 1), (1, -1))), - F(Operators.preDecrement, List((1, 1), (1, -1))), - F(Operators.preIncrement, List((1, 1), (1, -1))), - F(Operators.sizeOf, List.empty[(Int, Int)]), + // TODO does this still exist? + F(".incBy", List((1, 1), (2, 1), (3, 1), (4, 1))), + F(Operators.indexAccess, List((1, -1))), + F(Operators.indirectComputedMemberAccess, List((1, -1))), + F(Operators.indirectFieldAccess, List((1, -1))), + F(Operators.indirectIndexAccess, List((1, -1), (2, 1))), + F(Operators.indirectMemberAccess, List((1, -1))), + F(Operators.indirection, List((1, -1))), + F(Operators.memberAccess, List((1, -1))), + F(Operators.pointerShift, List((1, -1))), + F(Operators.postDecrement, List((1, 1), (1, -1))), + F(Operators.postIncrement, List((1, 1), (1, -1))), + F(Operators.preDecrement, List((1, 1), (1, -1))), + F(Operators.preIncrement, List((1, 1), (1, -1))), + F(Operators.sizeOf, List.empty[(Int, Int)]), - // some of those operators have duplicate mappings due to a typo - // - see https://github.com/ShiftLeftSecurity/codepropertygraph/pull/1630 + // some of those operators have duplicate mappings due to a typo + // - see https://github.com/ShiftLeftSecurity/codepropertygraph/pull/1630 - F(".assignmentExponentiation", List((2, 1), (1, 1))), - F(".assignmentModulo", List((2, 1), (1, 1))), - F(".assignmentShiftLeft", List((2, 1), (1, 1))), - F(".assignmentLogicalShiftRight", List((2, 1), (1, 1))), - F(".assignmentArithmeticShiftRight", List((2, 1), (1, 1))), - F(".assignmentAnd", List((2, 1), (1, 1))), - F(".assignmentOr", List((2, 1), (1, 1))), - F(".assignmentXor", List((2, 1), (1, 1))), + F(".assignmentExponentiation", List((2, 1), (1, 1))), + F(".assignmentModulo", List((2, 1), (1, 1))), + F(".assignmentShiftLeft", List((2, 1), (1, 1))), + F(".assignmentLogicalShiftRight", List((2, 1), (1, 1))), + F(".assignmentArithmeticShiftRight", List((2, 1), (1, 1))), + F(".assignmentAnd", List((2, 1), (1, 1))), + F(".assignmentOr", List((2, 1), (1, 1))), + F(".assignmentXor", List((2, 1), (1, 1))), - // Language specific operators - PTF(".tupleLiteral"), - PTF(".dictLiteral"), - PTF(".setLiteral"), - PTF(".listLiteral") - ) + // Language specific operators + PTF(".tupleLiteral"), + PTF(".dictLiteral"), + PTF(".setLiteral"), + PTF(".listLiteral") + ) - /** Semantic summaries for common external C/C++ calls. - * - * @see - * Standard - * C Library Functions - */ - def cFlows: List[FlowSemantic] = List( - F("abs", List((1, 1), (1, -1))), - F("abort", List.empty[(Int, Int)]), - F("asctime", List((1, 1), (1, -1))), - F("asctime_r", List((1, 1), (1, -1))), - F("atof", List((1, 1), (1, -1))), - F("atoi", List((1, 1), (1, -1))), - F("atol", List((1, 1), (1, -1))), - F("calloc", List((1, -1), (2, -1))), - F("ceil", List((1, 1), (1, 1))), - F("clock", List.empty[(Int, Int)]), - F("ctime", List((1, -1))), - F("ctime64", List((1, -1))), - F("ctime_r", List((1, -1))), - F("ctime64_r", List((1, -1))), - F("difftime", List((1, -1), (2, -1))), - F("difftime64", List((1, -1), (2, -1))), - PTF("div"), - F("exit", List((1, 1))), - F("exp", List((1, -1))), - F("fabs", List((1, -1))), - F("fclose", List((1, 1), (1, -1))), - F("fdopen", List((1, -1), (2, -1))), - F("feof", List((1, 1), (1, -1))), - F("ferror", List((1, 1), (1, -1))), - F("fflush", List((1, 1), (1, -1))), - F("fgetc", List((1, 1), (1, -1))), - F("fwrite", List((1, 1), (1, -1), (2, -1), (3, -1), (4, -1))), - F("free", List((1, 1))), - F("getc", List((1, 1))), - F("scanf", List((2, 2))), - F("strcmp", List((1, 1), (1, -1), (2, 2), (2, -1))), - F("strlen", List((1, 1), (1, -1))), - F("strncpy", List((1, 1), (2, 2), (3, 3), (1, -1), (2, -1))), - F("strncat", List((1, 1), (1, -1), (2, 2), (2, -1))) - ) + /** Semantic summaries for common external C/C++ calls. + * + * @see + * Standard + * C Library Functions + */ + def cFlows: List[FlowSemantic] = List( + F("abs", List((1, 1), (1, -1))), + F("abort", List.empty[(Int, Int)]), + F("asctime", List((1, 1), (1, -1))), + F("asctime_r", List((1, 1), (1, -1))), + F("atof", List((1, 1), (1, -1))), + F("atoi", List((1, 1), (1, -1))), + F("atol", List((1, 1), (1, -1))), + F("calloc", List((1, -1), (2, -1))), + F("ceil", List((1, 1), (1, 1))), + F("clock", List.empty[(Int, Int)]), + F("ctime", List((1, -1))), + F("ctime64", List((1, -1))), + F("ctime_r", List((1, -1))), + F("ctime64_r", List((1, -1))), + F("difftime", List((1, -1), (2, -1))), + F("difftime64", List((1, -1), (2, -1))), + PTF("div"), + F("exit", List((1, 1))), + F("exp", List((1, -1))), + F("fabs", List((1, -1))), + F("fclose", List((1, 1), (1, -1))), + F("fdopen", List((1, -1), (2, -1))), + F("feof", List((1, 1), (1, -1))), + F("ferror", List((1, 1), (1, -1))), + F("fflush", List((1, 1), (1, -1))), + F("fgetc", List((1, 1), (1, -1))), + F("fwrite", List((1, 1), (1, -1), (2, -1), (3, -1), (4, -1))), + F("free", List((1, 1))), + F("getc", List((1, 1))), + F("scanf", List((2, 2))), + F("strcmp", List((1, 1), (1, -1), (2, 2), (2, -1))), + F("strlen", List((1, 1), (1, -1))), + F("strncpy", List((1, 1), (2, 2), (3, 3), (1, -1), (2, -1))), + F("strncat", List((1, 1), (1, -1), (2, 2), (2, -1))) + ) - /** Semantic summaries for common external Java calls. - */ - def javaFlows: List[FlowSemantic] = List( - PTF("java.lang.String.split:java.lang.String[](java.lang.String)", List((0, 0))), - PTF("java.lang.String.split:java.lang.String[](java.lang.String,int)", List((0, 0))), - PTF("java.lang.String.compareTo:int(java.lang.String)", List((0, 0))), - F("java.io.PrintWriter.print:void(java.lang.String)", List((0, 0), (1, 1))), - F("java.io.PrintWriter.println:void(java.lang.String)", List((0, 0), (1, 1))), - F("java.io.PrintStream.println:void(java.lang.String)", List((0, 0), (1, 1))), - PTF("java.io.PrintStream.print:void(java.lang.String)", List((0, 0))), - F("android.text.TextUtils.isEmpty:boolean(java.lang.String)", List((0, -1), (1, -1))), - F( - "java.sql.PreparedStatement.prepareStatement:java.sql.PreparedStatement(java.lang.String)", - List((1, -1)) - ), - F("java.sql.PreparedStatement.prepareStatement:setDouble(int,double)", List((1, 1), (2, 2))), - F("java.sql.PreparedStatement.prepareStatement:setFloat(int,float)", List((1, 1), (2, 2))), - F("java.sql.PreparedStatement.prepareStatement:setInt(int,int)", List((1, 1), (2, 2))), - F("java.sql.PreparedStatement.prepareStatement:setLong(int,long)", List((1, 1), (2, 2))), - F("java.sql.PreparedStatement.prepareStatement:setShort(int,short)", List((1, 1), (2, 2))), - F( - "java.sql.PreparedStatement.prepareStatement:setString(int,java.lang.String)", - List((1, 1), (2, 2)) - ), - F( - "org.apache.http.HttpRequest.:void(org.apache.http.RequestLine)", - List((1, 1), (1, 0)) - ), - F( - "org.apache.http.HttpRequest.:void(java.lang.String,java.lang.String)", - List((1, 1), (1, 0), (2, 0)) - ), - F( - "org.apache.http.HttpRequest.:void(java.lang.String,java.lang.String,org.apache.http.ProtocolVersion)", - List((1, 1), (1, 0), (2, 2), (2, 0), (3, 3), (3, 0)) - ), - F("org.apache.http.HttpResponse.getStatusLine:org.apache.http.StatusLine()", List((0, -1))), - F( - "org.apache.http.HttpResponse.setStatusLine:void(org.apache.http.StatusLine)", - List((1, 0), (1, 1), (0, -1)) - ), - F( - "org.apache.http.HttpResponse.setReasonPhrase:void(java.lang.String)", - List((1, 0), (1, 1), (0, -1)) - ), - F("org.apache.http.HttpResponse.getEntity:org.apache.http.HttpEntity()", List((0, -1))), - F( - "org.apache.http.HttpResponse.setEntity:void(org.apache.http.HttpEntity)", - List((1, 0), (1, 1), (1, 0)) - ) + /** Semantic summaries for common external Java calls. + */ + def javaFlows: List[FlowSemantic] = List( + PTF("java.lang.String.split:java.lang.String[](java.lang.String)", List((0, 0))), + PTF("java.lang.String.split:java.lang.String[](java.lang.String,int)", List((0, 0))), + PTF("java.lang.String.compareTo:int(java.lang.String)", List((0, 0))), + F("java.io.PrintWriter.print:void(java.lang.String)", List((0, 0), (1, 1))), + F("java.io.PrintWriter.println:void(java.lang.String)", List((0, 0), (1, 1))), + F("java.io.PrintStream.println:void(java.lang.String)", List((0, 0), (1, 1))), + PTF("java.io.PrintStream.print:void(java.lang.String)", List((0, 0))), + F("android.text.TextUtils.isEmpty:boolean(java.lang.String)", List((0, -1), (1, -1))), + F( + "java.sql.PreparedStatement.prepareStatement:java.sql.PreparedStatement(java.lang.String)", + List((1, -1)) + ), + F("java.sql.PreparedStatement.prepareStatement:setDouble(int,double)", List((1, 1), (2, 2))), + F("java.sql.PreparedStatement.prepareStatement:setFloat(int,float)", List((1, 1), (2, 2))), + F("java.sql.PreparedStatement.prepareStatement:setInt(int,int)", List((1, 1), (2, 2))), + F("java.sql.PreparedStatement.prepareStatement:setLong(int,long)", List((1, 1), (2, 2))), + F("java.sql.PreparedStatement.prepareStatement:setShort(int,short)", List((1, 1), (2, 2))), + F( + "java.sql.PreparedStatement.prepareStatement:setString(int,java.lang.String)", + List((1, 1), (2, 2)) + ), + F( + "org.apache.http.HttpRequest.:void(org.apache.http.RequestLine)", + List((1, 1), (1, 0)) + ), + F( + "org.apache.http.HttpRequest.:void(java.lang.String,java.lang.String)", + List((1, 1), (1, 0), (2, 0)) + ), + F( + "org.apache.http.HttpRequest.:void(java.lang.String,java.lang.String,org.apache.http.ProtocolVersion)", + List((1, 1), (1, 0), (2, 2), (2, 0), (3, 3), (3, 0)) + ), + F("org.apache.http.HttpResponse.getStatusLine:org.apache.http.StatusLine()", List((0, -1))), + F( + "org.apache.http.HttpResponse.setStatusLine:void(org.apache.http.StatusLine)", + List((1, 0), (1, 1), (0, -1)) + ), + F( + "org.apache.http.HttpResponse.setReasonPhrase:void(java.lang.String)", + List((1, 0), (1, 1), (0, -1)) + ), + F("org.apache.http.HttpResponse.getEntity:org.apache.http.HttpEntity()", List((0, -1))), + F( + "org.apache.http.HttpResponse.setEntity:void(org.apache.http.HttpEntity)", + List((1, 0), (1, 1), (1, 0)) ) + ) - /** @return - * procedure semantics for operators and common external Java calls only. - */ - @unused - def javaSemantics(): Semantics = Semantics.fromList(operatorFlows ++ javaFlows) + /** @return + * procedure semantics for operators and common external Java calls only. + */ + @unused + def javaSemantics(): Semantics = Semantics.fromList(operatorFlows ++ javaFlows) end DefaultSemantics diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/dotgenerator/DdgGenerator.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/dotgenerator/DdgGenerator.scala index 2ae58538..343f80f0 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/dotgenerator/DdgGenerator.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/dotgenerator/DdgGenerator.scala @@ -15,111 +15,111 @@ import scala.collection.mutable class DdgGenerator: - val edgeType = "DDG" - private val edgeCache = mutable.Map[StoredNode, List[Edge]]() - - def generate(methodNode: Method)(implicit semantics: Semantics = DefaultSemantics()): Graph = - val entryNode = methodNode - val paramNodes = methodNode.parameter.l - val allOtherNodes = methodNode.cfgNode.l - val exitNode = methodNode.methodReturn - val allNodes: List[StoredNode] = List(entryNode, exitNode) ++ paramNodes ++ allOtherNodes - val visibleNodes = allNodes.filter(shouldBeDisplayed) - - val edges = visibleNodes.map { dstNode => - inEdgesToDisplay(dstNode) + val edgeType = "DDG" + private val edgeCache = mutable.Map[StoredNode, List[Edge]]() + + def generate(methodNode: Method)(implicit semantics: Semantics = DefaultSemantics()): Graph = + val entryNode = methodNode + val paramNodes = methodNode.parameter.l + val allOtherNodes = methodNode.cfgNode.l + val exitNode = methodNode.methodReturn + val allNodes: List[StoredNode] = List(entryNode, exitNode) ++ paramNodes ++ allOtherNodes + val visibleNodes = allNodes.filter(shouldBeDisplayed) + + val edges = visibleNodes.map { dstNode => + inEdgesToDisplay(dstNode) + } + + val allIdsReferencedByEdges = edges.flatten.flatMap { edge => + Set(edge.src.id, edge.dst.id) + } + + val ddgNodes = visibleNodes + .filter(node => allIdsReferencedByEdges.contains(node.id)) + .map(surroundingCall) + .filterNot(node => + node.isInstanceOf[Call] && isGenericMemberAccessName(node.asInstanceOf[Call].name) + ) + + val ddgEdges = edges.flatten + .map { edge => + edge.copy(src = surroundingCall(edge.src), dst = surroundingCall(edge.dst)) } - - val allIdsReferencedByEdges = edges.flatten.flatMap { edge => - Set(edge.src.id, edge.dst.id) - } - - val ddgNodes = visibleNodes - .filter(node => allIdsReferencedByEdges.contains(node.id)) - .map(surroundingCall) - .filterNot(node => - node.isInstanceOf[Call] && isGenericMemberAccessName(node.asInstanceOf[Call].name) - ) - - val ddgEdges = edges.flatten - .map { edge => - edge.copy(src = surroundingCall(edge.src), dst = surroundingCall(edge.dst)) - } - .filter(e => e.src != e.dst) - .filterNot(e => - e.dst.isInstanceOf[Call] && isGenericMemberAccessName(e.dst.asInstanceOf[Call].name) - ) - .filterNot(e => - e.src.isInstanceOf[Call] && isGenericMemberAccessName(e.src.asInstanceOf[Call].name) + .filter(e => e.src != e.dst) + .filterNot(e => + e.dst.isInstanceOf[Call] && isGenericMemberAccessName(e.dst.asInstanceOf[Call].name) + ) + .filterNot(e => + e.src.isInstanceOf[Call] && isGenericMemberAccessName(e.src.asInstanceOf[Call].name) + ) + .distinct + + edgeCache.clear() + Graph(ddgNodes, ddgEdges) + end generate + + private def surroundingCall(node: StoredNode): StoredNode = + node match + case arg: Expression => arg.inCall.headOption.getOrElse(node) + case _ => node + + private def shouldBeDisplayed(v: Node): Boolean = !( + v.isInstanceOf[ControlStructure] || + v.isInstanceOf[JumpTarget] + ) + + private def inEdgesToDisplay(dstNode: StoredNode, visited: List[StoredNode] = List())(implicit + semantics: Semantics + ): List[Edge] = + + if edgeCache.contains(dstNode) then + return edgeCache(dstNode) + + if visited.contains(dstNode) then + List() + else + val parents = expand(dstNode) + val (visible, invisible) = + parents.partition(x => shouldBeDisplayed(x.src) && x.srcVisible) + val result = visible.toList ++ invisible.toList.flatMap { n => + val parentInEdgesToDisplay = inEdgesToDisplay(n.src, visited ++ List(dstNode)) + parentInEdgesToDisplay.map(y => + Edge(y.src, dstNode, y.srcVisible, edgeType = edgeType, label = y.label) + ) + }.distinct + edgeCache.put(dstNode, result) + result + end inEdgesToDisplay + + private def expand(v: StoredNode)(implicit semantics: Semantics): Iterator[Edge] = + + val allInEdges = v + .inE(EdgeTypes.REACHING_DEF) + .map(x => + Edge( + x.outNode.asInstanceOf[StoredNode], + v, + srcVisible = true, + x.property(Properties.VARIABLE), + edgeType ) - .distinct - - edgeCache.clear() - Graph(ddgNodes, ddgEdges) - end generate - - private def surroundingCall(node: StoredNode): StoredNode = - node match - case arg: Expression => arg.inCall.headOption.getOrElse(node) - case _ => node - - private def shouldBeDisplayed(v: Node): Boolean = !( - v.isInstanceOf[ControlStructure] || - v.isInstanceOf[JumpTarget] - ) - - private def inEdgesToDisplay(dstNode: StoredNode, visited: List[StoredNode] = List())(implicit - semantics: Semantics - ): List[Edge] = - - if edgeCache.contains(dstNode) then - return edgeCache(dstNode) - - if visited.contains(dstNode) then - List() - else - val parents = expand(dstNode) - val (visible, invisible) = - parents.partition(x => shouldBeDisplayed(x.src) && x.srcVisible) - val result = visible.toList ++ invisible.toList.flatMap { n => - val parentInEdgesToDisplay = inEdgesToDisplay(n.src, visited ++ List(dstNode)) - parentInEdgesToDisplay.map(y => - Edge(y.src, dstNode, y.srcVisible, edgeType = edgeType, label = y.label) - ) - }.distinct - edgeCache.put(dstNode, result) - result - end inEdgesToDisplay - - private def expand(v: StoredNode)(implicit semantics: Semantics): Iterator[Edge] = - - val allInEdges = v - .inE(EdgeTypes.REACHING_DEF) - .map(x => - Edge( - x.outNode.asInstanceOf[StoredNode], - v, - srcVisible = true, - x.property(Properties.VARIABLE), - edgeType - ) - ) - - v match - case cfgNode: CfgNode => - cfgNode - .ddgInPathElem(withInvisible = true) - .map(x => - Edge( - x.node.asInstanceOf[StoredNode], - v, - x.visible, - x.outEdgeLabel, - edgeType - ) - ) - .iterator ++ allInEdges.filter(_.src.isInstanceOf[Method]).iterator - case _ => - allInEdges.iterator - end expand + ) + + v match + case cfgNode: CfgNode => + cfgNode + .ddgInPathElem(withInvisible = true) + .map(x => + Edge( + x.node.asInstanceOf[StoredNode], + v, + x.visible, + x.outEdgeLabel, + edgeType + ) + ) + .iterator ++ allInEdges.filter(_.src.isInstanceOf[Method]).iterator + case _ => + allInEdges.iterator + end expand end DdgGenerator diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/dotgenerator/DotCpg14Generator.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/dotgenerator/DotCpg14Generator.scala index 685e3dbd..bcf1c2af 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/dotgenerator/DotCpg14Generator.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/dotgenerator/DotCpg14Generator.scala @@ -12,14 +12,14 @@ import io.shiftleft.semanticcpg.dotgenerator.{ object DotCpg14Generator: - def toDotCpg14(traversal: Iterator[Method])(implicit - semantics: Semantics = DefaultSemantics() - ): Iterator[String] = - traversal.map(dotGraphForMethod) + def toDotCpg14(traversal: Iterator[Method])(implicit + semantics: Semantics = DefaultSemantics() + ): Iterator[String] = + traversal.map(dotGraphForMethod) - private def dotGraphForMethod(method: Method)(implicit semantics: Semantics): String = - val ast = new AstGenerator().generate(method) - val cfg = new CfgGenerator().generate(method) - val ddg = new DdgGenerator().generate(method) - val cdg = new CdgGenerator().generate(method) - DotSerializer.dotGraph(Option(method), ast ++ cfg ++ ddg ++ cdg, withEdgeTypes = true) + private def dotGraphForMethod(method: Method)(implicit semantics: Semantics): String = + val ast = new AstGenerator().generate(method) + val cfg = new CfgGenerator().generate(method) + val ddg = new DdgGenerator().generate(method) + val cdg = new CdgGenerator().generate(method) + DotSerializer.dotGraph(Option(method), ast ++ cfg ++ ddg ++ cdg, withEdgeTypes = true) diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/dotgenerator/DotDdgGenerator.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/dotgenerator/DotDdgGenerator.scala index 2c5ffdef..cede5b9e 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/dotgenerator/DotDdgGenerator.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/dotgenerator/DotDdgGenerator.scala @@ -7,12 +7,12 @@ import io.shiftleft.semanticcpg.dotgenerator.DotSerializer object DotDdgGenerator: - def toDotDdg(traversal: Iterator[Method])(implicit - semantics: Semantics = DefaultSemantics() - ): Iterator[String] = - traversal.map(dotGraphForMethod) + def toDotDdg(traversal: Iterator[Method])(implicit + semantics: Semantics = DefaultSemantics() + ): Iterator[String] = + traversal.map(dotGraphForMethod) - private def dotGraphForMethod(method: Method)(implicit semantics: Semantics): String = - val ddgGenerator = new DdgGenerator() - val ddg = ddgGenerator.generate(method) - DotSerializer.dotGraph(Option(method), ddg) + private def dotGraphForMethod(method: Method)(implicit semantics: Semantics): String = + val ddgGenerator = new DdgGenerator() + val ddg = ddgGenerator.generate(method) + DotSerializer.dotGraph(Option(method), ddg) diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/dotgenerator/DotPdgGenerator.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/dotgenerator/DotPdgGenerator.scala index 78818e10..1861c7a3 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/dotgenerator/DotPdgGenerator.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/dotgenerator/DotPdgGenerator.scala @@ -7,12 +7,12 @@ import io.shiftleft.semanticcpg.dotgenerator.{CdgGenerator, DotSerializer} object DotPdgGenerator: - def toDotPdg(traversal: Iterator[Method])(implicit - semantics: Semantics = DefaultSemantics() - ): Iterator[String] = - traversal.map(dotGraphForMethod) + def toDotPdg(traversal: Iterator[Method])(implicit + semantics: Semantics = DefaultSemantics() + ): Iterator[String] = + traversal.map(dotGraphForMethod) - private def dotGraphForMethod(method: Method)(implicit semantics: Semantics): String = - val ddg = new DdgGenerator().generate(method) - val cdg = new CdgGenerator().generate(method) - DotSerializer.dotGraph(Option(method), ddg.++(cdg), withEdgeTypes = true) + private def dotGraphForMethod(method: Method)(implicit semantics: Semantics): String = + val ddg = new DdgGenerator().generate(method) + val cdg = new CdgGenerator().generate(method) + DotSerializer.dotGraph(Option(method), ddg.++(cdg), withEdgeTypes = true) diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/language/ExtendedCfgNode.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/language/ExtendedCfgNode.scala index f9d0d907..00bb32cc 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/language/ExtendedCfgNode.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/language/ExtendedCfgNode.scala @@ -20,99 +20,99 @@ import scala.collection.parallel.CollectionConverters.* */ class ExtendedCfgNode(val traversal: Iterator[CfgNode]) extends AnyVal: - def ddgIn(implicit semantics: Semantics = DefaultSemantics()): Iterator[CfgNode] = - val cache = mutable.HashMap[CfgNode, Vector[PathElement]]() - val result = - traversal.flatMap(x => x.ddgIn(Vector(PathElement(x)), withInvisible = false, cache)) - cache.clear() - result + def ddgIn(implicit semantics: Semantics = DefaultSemantics()): Iterator[CfgNode] = + val cache = mutable.HashMap[CfgNode, Vector[PathElement]]() + val result = + traversal.flatMap(x => x.ddgIn(Vector(PathElement(x)), withInvisible = false, cache)) + cache.clear() + result - def ddgInPathElem(implicit semantics: Semantics = DefaultSemantics()): Iterator[PathElement] = - val cache = mutable.HashMap[CfgNode, Vector[PathElement]]() - val result = traversal.flatMap(x => - x.ddgInPathElem(Vector(PathElement(x)), withInvisible = false, cache) - ) - cache.clear() - result + def ddgInPathElem(implicit semantics: Semantics = DefaultSemantics()): Iterator[PathElement] = + val cache = mutable.HashMap[CfgNode, Vector[PathElement]]() + val result = traversal.flatMap(x => + x.ddgInPathElem(Vector(PathElement(x)), withInvisible = false, cache) + ) + cache.clear() + result - def reachableBy[NodeType]( - sourceTrav: IterableOnce[NodeType], - sourceTravs: IterableOnce[NodeType]* - )(implicit context: EngineContext): Iterator[NodeType] = - val sources = sourceTravsToStartingPoints(sourceTrav +: sourceTravs*) - val reachedSources = - reachableByInternal(sources).map(_.path.head.node) - reachedSources.cast[NodeType] + def reachableBy[NodeType]( + sourceTrav: IterableOnce[NodeType], + sourceTravs: IterableOnce[NodeType]* + )(implicit context: EngineContext): Iterator[NodeType] = + val sources = sourceTravsToStartingPoints(sourceTrav +: sourceTravs*) + val reachedSources = + reachableByInternal(sources).map(_.path.head.node) + reachedSources.cast[NodeType] - def df[A](sourceTrav: IterableOnce[A], sourceTravs: IterableOnce[A]*)(implicit - context: EngineContext - ): Iterator[Path] = reachableByFlows(sourceTrav, sourceTravs) + def df[A](sourceTrav: IterableOnce[A], sourceTravs: IterableOnce[A]*)(implicit + context: EngineContext + ): Iterator[Path] = reachableByFlows(sourceTrav, sourceTravs) - def reachableByFlows[A](sourceTrav: IterableOnce[A], sourceTravs: IterableOnce[A]*)(implicit - context: EngineContext - ): Iterator[Path] = - val sources = sourceTravsToStartingPoints(sourceTrav +: sourceTravs*) - val startingPoints = sources.map(_.startingPoint) - val paths = reachableByInternal(sources).par - .map { result => - // We can get back results that start in nodes that are invisible - // according to the semantic, e.g., arguments that are only used - // but not defined. We filter these results here prior to returning - val first = result.path.headOption - if first.isDefined && !first.get.visible && !startingPoints.contains(first.get.node) - then - None - else - val visiblePathElements = - result.path.filter(x => startingPoints.contains(x.node) || x.visible) - Some(Path(removeConsecutiveDuplicates(visiblePathElements.map(_.node)))) - } - .filter(_.isDefined) - .dedup - .flatten - .toVector - paths.iterator - end reachableByFlows + def reachableByFlows[A](sourceTrav: IterableOnce[A], sourceTravs: IterableOnce[A]*)(implicit + context: EngineContext + ): Iterator[Path] = + val sources = sourceTravsToStartingPoints(sourceTrav +: sourceTravs*) + val startingPoints = sources.map(_.startingPoint) + val paths = reachableByInternal(sources).par + .map { result => + // We can get back results that start in nodes that are invisible + // according to the semantic, e.g., arguments that are only used + // but not defined. We filter these results here prior to returning + val first = result.path.headOption + if first.isDefined && !first.get.visible && !startingPoints.contains(first.get.node) + then + None + else + val visiblePathElements = + result.path.filter(x => startingPoints.contains(x.node) || x.visible) + Some(Path(removeConsecutiveDuplicates(visiblePathElements.map(_.node)))) + } + .filter(_.isDefined) + .dedup + .flatten + .toVector + paths.iterator + end reachableByFlows - def reachableByDetailed[NodeType]( - sourceTrav: Iterator[NodeType], - sourceTravs: Iterator[NodeType]* - )(implicit context: EngineContext): Vector[TableEntry] = - val sources = - SourcesToStartingPoints.sourceTravsToStartingPoints(sourceTrav +: sourceTravs*) - reachableByInternal(sources) + def reachableByDetailed[NodeType]( + sourceTrav: Iterator[NodeType], + sourceTravs: Iterator[NodeType]* + )(implicit context: EngineContext): Vector[TableEntry] = + val sources = + SourcesToStartingPoints.sourceTravsToStartingPoints(sourceTrav +: sourceTravs*) + reachableByInternal(sources) - private def removeConsecutiveDuplicates[T](l: Vector[T]): List[T] = - l.headOption.map(x => - x :: l.sliding(2).collect { case Seq(a, b) if a != b => b }.toList - ).getOrElse(List()) + private def removeConsecutiveDuplicates[T](l: Vector[T]): List[T] = + l.headOption.map(x => + x :: l.sliding(2).collect { case Seq(a, b) if a != b => b }.toList + ).getOrElse(List()) - private def reachableByInternal( - startingPointsWithSources: List[StartingPointWithSource] - )(implicit context: EngineContext): Vector[TableEntry] = - val sinks = traversal.dedup.toList.sortBy(_.id) - val engine = new Engine(context) - val result = engine.backwards(sinks, startingPointsWithSources.map(_.startingPoint)) + private def reachableByInternal( + startingPointsWithSources: List[StartingPointWithSource] + )(implicit context: EngineContext): Vector[TableEntry] = + val sinks = traversal.dedup.toList.sortBy(_.id) + val engine = new Engine(context) + val result = engine.backwards(sinks, startingPointsWithSources.map(_.startingPoint)) - engine.shutdown() - val sources = startingPointsWithSources.map(_.source) - val startingPointToSource = startingPointsWithSources.map { x => - x.startingPoint.asInstanceOf[AstNode] -> x.source - }.toMap - val res = result.par.map { r => - val startingPoint = r.path.head.node - if sources.contains(startingPoint) || !startingPointToSource( - startingPoint - ).isInstanceOf[AstNode] - then - r - else - r.copy(path = - PathElement( - startingPointToSource(startingPoint).asInstanceOf[AstNode] - ) +: r.path - ) - } - res.toVector - end reachableByInternal + engine.shutdown() + val sources = startingPointsWithSources.map(_.source) + val startingPointToSource = startingPointsWithSources.map { x => + x.startingPoint.asInstanceOf[AstNode] -> x.source + }.toMap + val res = result.par.map { r => + val startingPoint = r.path.head.node + if sources.contains(startingPoint) || !startingPointToSource( + startingPoint + ).isInstanceOf[AstNode] + then + r + else + r.copy(path = + PathElement( + startingPointToSource(startingPoint).asInstanceOf[AstNode] + ) +: r.path + ) + } + res.toVector + end reachableByInternal end ExtendedCfgNode diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/language/Path.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/language/Path.scala index f1ba3c70..dc86af10 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/language/Path.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/language/Path.scala @@ -8,224 +8,224 @@ import org.apache.commons.lang.StringUtils import scala.collection.mutable.{ArrayBuffer, Set} case class Path(elements: List[AstNode]): - def resultPairs(): List[(String, Option[Integer])] = - val pairs = elements.map { - case point: MethodParameterIn => - val method = point.method - val method_name = method.name - val code = - s"$method_name(${method.parameter.l.sortBy(_.order).map(_.code).mkString(", ")})" - (code, point.lineNumber) - case point => (point.statement.repr, point.lineNumber) - } - pairs.headOption - .map(x => - x :: pairs.sliding(2).collect { - case Seq(a, b) if a._1 != b._1 && a._2 != b._2 => b - }.toList - ) - .getOrElse(List()) + def resultPairs(): List[(String, Option[Integer])] = + val pairs = elements.map { + case point: MethodParameterIn => + val method = point.method + val method_name = method.name + val code = + s"$method_name(${method.parameter.l.sortBy(_.order).map(_.code).mkString(", ")})" + (code, point.lineNumber) + case point => (point.statement.repr, point.lineNumber) + } + pairs.headOption + .map(x => + x :: pairs.sliding(2).collect { + case Seq(a, b) if a._1 != b._1 && a._2 != b._2 => b + }.toList + ) + .getOrElse(List()) object Path: - private val DefaultMaxTrackedWidth = 128 - // TODO replace with dynamic rendering based on the terminal's width, e.g. in scala-repl-pp - private lazy val maxTrackedWidth = - sys.env.get("CHEN_DATAFLOW_TRACKED_WIDTH").map(_.toInt).getOrElse(DefaultMaxTrackedWidth) + private val DefaultMaxTrackedWidth = 128 + // TODO replace with dynamic rendering based on the terminal's width, e.g. in scala-repl-pp + private lazy val maxTrackedWidth = + sys.env.get("CHEN_DATAFLOW_TRACKED_WIDTH").map(_.toInt).getOrElse(DefaultMaxTrackedWidth) - private def tagAsString(tag: Iterator[Tag]) = - if tag.nonEmpty then tag.name.mkString(", ") else "" + private def tagAsString(tag: Iterator[Tag]) = + if tag.nonEmpty then tag.name.mkString(", ") else "" - implicit val show: Show[Path] = path => - var caption = "" - if path.elements.size > 2 then - val srcNode = path.elements.head - val srcTags = tagAsString(srcNode.tag) - val sinkNode = path.elements.last - var sinkCode = sinkNode.code - val sinkTags = tagAsString(sinkNode.tag) - sinkNode match - case cfgNode: CfgNode => - val method = cfgNode.method - sinkCode = method.fullName - caption = if srcNode.code != "this" then s"Source: ${srcNode.code}" else "" - if srcTags.nonEmpty then caption += s"\nSource Tags: ${srcTags}" - caption += s"\nSink: ${sinkCode}\n" - if sinkTags.nonEmpty then caption += s"Sink Tags: ${sinkTags}\n" - var hasCheckLike: Boolean = false; - val tableRows = ArrayBuffer[Array[String]]() - val addedPaths = Set[String]() - path.elements.foreach { astNode => - val nodeType = astNode.getClass.getSimpleName - val lineNumber = astNode.lineNumber.getOrElse("").toString - val fileName = astNode.file.name.headOption.getOrElse("").replace("", "") - var fileLocation = s"${fileName}#${lineNumber}" - var tags: String = tagAsString(astNode.tag) - if fileLocation == "#" then fileLocation = "N/A" - astNode match - case _: MethodReturn => - case methodParameterIn: MethodParameterIn => - val methodName = methodParameterIn.method.name - if tags.isEmpty && methodParameterIn.method.tag.nonEmpty then - tags = tagAsString(methodParameterIn.method.tag) - if tags.isEmpty && methodParameterIn.tag.nonEmpty then - tags = tagAsString(methodParameterIn.tag) - tableRows += Array[String]( - "methodParameterIn", - fileLocation, - methodName, - s"[bold red]${methodParameterIn.name}[/bold red]", - methodParameterIn.method.fullName + (if methodParameterIn.method.isExternal - then " :right_arrow_curving_up:" - else ""), - tags - ) - case ret: Return => - val methodName = ret.method.name - tableRows += Array[String]( - "return", - fileLocation, - methodName, - ret.argumentName.getOrElse(""), - ret.code, - tags - ) - case identifier: Identifier => - val methodName = identifier.method.name - if tags.isEmpty && identifier.inCall.nonEmpty && identifier.inCall.head.tag.nonEmpty - then - tags = tagAsString(identifier.inCall.head.tag) - if !addedPaths.contains( - s"${fileName}#${lineNumber}" - ) && identifier.inCall.nonEmpty - then - tableRows += Array[String]( - "identifier", - fileLocation, - methodName, - identifier.name, - if identifier.inCall.nonEmpty then - identifier.inCall.head.code - else identifier.code, - tags - ) - case member: Member => - val methodName = "" - tableRows += Array[String]( - "member", - fileLocation, - methodName, - nodeType, - member.name, - member.code, - tags - ) - case call: Call => - if !call.code.startsWith(" - val method = cfgNode.method - if tags.isEmpty && method.tag.nonEmpty then - tags = tagAsString(method.tag) - val methodName = method.name - val statement = cfgNode match - case _: MethodParameterIn => - if tags.isEmpty && method.parameter.tag.nonEmpty then - tags = tagAsString(method.parameter.tag) - val paramsPretty = - method.parameter.toList.sortBy(_.index).map(_.code).mkString(", ") - s"$methodName($paramsPretty)" - case _ => - if tags.isEmpty && cfgNode.statement.tag.nonEmpty then - tags = tagAsString(cfgNode.statement.tag) - cfgNode.statement.repr - val tracked = StringUtils.normalizeSpace(StringUtils.abbreviate( - statement, - maxTrackedWidth - )) - tableRows += Array[String]( - "cfgNode", - fileLocation, - methodName, - "", - tracked, - tags + implicit val show: Show[Path] = path => + var caption = "" + if path.elements.size > 2 then + val srcNode = path.elements.head + val srcTags = tagAsString(srcNode.tag) + val sinkNode = path.elements.last + var sinkCode = sinkNode.code + val sinkTags = tagAsString(sinkNode.tag) + sinkNode match + case cfgNode: CfgNode => + val method = cfgNode.method + sinkCode = method.fullName + caption = if srcNode.code != "this" then s"Source: ${srcNode.code}" else "" + if srcTags.nonEmpty then caption += s"\nSource Tags: ${srcTags}" + caption += s"\nSink: ${sinkCode}\n" + if sinkTags.nonEmpty then caption += s"Sink Tags: ${sinkTags}\n" + var hasCheckLike: Boolean = false; + val tableRows = ArrayBuffer[Array[String]]() + val addedPaths = Set[String]() + path.elements.foreach { astNode => + val nodeType = astNode.getClass.getSimpleName + val lineNumber = astNode.lineNumber.getOrElse("").toString + val fileName = astNode.file.name.headOption.getOrElse("").replace("", "") + var fileLocation = s"${fileName}#${lineNumber}" + var tags: String = tagAsString(astNode.tag) + if fileLocation == "#" then fileLocation = "N/A" + astNode match + case _: MethodReturn => + case methodParameterIn: MethodParameterIn => + val methodName = methodParameterIn.method.name + if tags.isEmpty && methodParameterIn.method.tag.nonEmpty then + tags = tagAsString(methodParameterIn.method.tag) + if tags.isEmpty && methodParameterIn.tag.nonEmpty then + tags = tagAsString(methodParameterIn.tag) + tableRows += Array[String]( + "methodParameterIn", + fileLocation, + methodName, + s"[bold red]${methodParameterIn.name}[/bold red]", + methodParameterIn.method.fullName + (if methodParameterIn.method.isExternal + then " :right_arrow_curving_up:" + else ""), + tags + ) + case ret: Return => + val methodName = ret.method.name + tableRows += Array[String]( + "return", + fileLocation, + methodName, + ret.argumentName.getOrElse(""), + ret.code, + tags + ) + case identifier: Identifier => + val methodName = identifier.method.name + if tags.isEmpty && identifier.inCall.nonEmpty && identifier.inCall.head.tag.nonEmpty + then + tags = tagAsString(identifier.inCall.head.tag) + if !addedPaths.contains( + s"${fileName}#${lineNumber}" + ) && identifier.inCall.nonEmpty + then + tableRows += Array[String]( + "identifier", + fileLocation, + methodName, + identifier.name, + if identifier.inCall.nonEmpty then + identifier.inCall.head.code + else identifier.code, + tags + ) + case member: Member => + val methodName = "" + tableRows += Array[String]( + "member", + fileLocation, + methodName, + nodeType, + member.name, + member.code, + tags + ) + case call: Call => + if !call.code.startsWith(" - caption + then " :right_arrow_curving_up:" + else "" + if call.methodFullName.startsWith(" + val method = cfgNode.method + if tags.isEmpty && method.tag.nonEmpty then + tags = tagAsString(method.tag) + val methodName = method.name + val statement = cfgNode match + case _: MethodParameterIn => + if tags.isEmpty && method.parameter.tag.nonEmpty then + tags = tagAsString(method.parameter.tag) + val paramsPretty = + method.parameter.toList.sortBy(_.index).map(_.code).mkString(", ") + s"$methodName($paramsPretty)" + case _ => + if tags.isEmpty && cfgNode.statement.tag.nonEmpty then + tags = tagAsString(cfgNode.statement.tag) + cfgNode.statement.repr + val tracked = StringUtils.normalizeSpace(StringUtils.abbreviate( + statement, + maxTrackedWidth + )) + tableRows += Array[String]( + "cfgNode", + fileLocation, + methodName, + "", + tracked, + tags + ) + end match + if isCheckLike(tags) then hasCheckLike = true + addedPaths += s"${fileName}#${lineNumber}" + } + try + if hasCheckLike then caption = s"This flow has mitigations in place.\n$caption" + printFlows(tableRows, caption) + catch + case exc: Exception => + caption - private def addEmphasis(str: String, isCheckLike: Boolean): String = - if isCheckLike then s"[green]$str[/green]" else str + private def addEmphasis(str: String, isCheckLike: Boolean): String = + if isCheckLike then s"[green]$str[/green]" else str - private def simplifyFilePath(str: String): String = - str.replace("src/main/java/", "").replace("src/main/scala/", "") + private def simplifyFilePath(str: String): String = + str.replace("src/main/java/", "").replace("src/main/scala/", "") - private def isCheckLike(tagsStr: String): Boolean = - tagsStr.contains("valid") || tagsStr.contains("encrypt") || tagsStr.contains( - "encode" - ) || tagsStr.contains( - "transform" - ) || tagsStr.contains("check") + private def isCheckLike(tagsStr: String): Boolean = + tagsStr.contains("valid") || tagsStr.contains("encrypt") || tagsStr.contains( + "encode" + ) || tagsStr.contains( + "transform" + ) || tagsStr.contains("check") - private def printFlows(tableRows: ArrayBuffer[Array[String]], caption: String): Unit = - val richTableLib = py.module("rich.table") - val richConsole = py.module("chenpy.logger").console - val table = richTableLib.Table(highlight = true, expand = true, caption = caption) - Array("Location", "Method", "Parameter", "Tracked").foreach(c => table.add_column(c)) - tableRows.foreach { row => - val end_section = row.head == "call" - val trow: Array[String] = row.tail - if !trow(3).startsWith(" table.add_column(c)) + tableRows.foreach { row => + val end_section = row.head == "call" + val trow: Array[String] = row.tail + if !trow(3).startsWith(" - srcName == node.argumentName.get - case FlowMapping(ParameterNode(srcIndex, _), _) => srcIndex == node.argumentIndex + /** Determine whether evaluation of the call this argument is a part of results in usage of this + * argument. + */ + def isUsed(implicit semantics: Semantics): Boolean = + val s = semanticsForCallByArg + s.isEmpty || s.exists(_.mappings.exists { + case FlowMapping(ParameterNode(_, Some(srcName)), _) if node.argumentName.isDefined => + srcName == node.argumentName.get + case FlowMapping(ParameterNode(srcIndex, _), _) => srcIndex == node.argumentIndex + case PassThroughMapping if node.argumentIndex != 0 => true + case _ => false + }) + + /** Determine whether evaluation of the call this argument is a part of results in definition of + * this argument. + */ + def isDefined(implicit semantics: Semantics): Boolean = + val s = semanticsForCallByArg.l + s.isEmpty || s.exists { semantic => + semantic.mappings.exists { + case FlowMapping(_, ParameterNode(_, Some(dstName))) + if node.argumentName.isDefined => + dstName == node.argumentName.get + case FlowMapping(_, ParameterNode(dstIndex, _)) => dstIndex == node.argumentIndex case PassThroughMapping if node.argumentIndex != 0 => true case _ => false - }) + } + } - /** Determine whether evaluation of the call this argument is a part of results in definition of - * this argument. - */ - def isDefined(implicit semantics: Semantics): Boolean = + /** Determines if this node and the given target node are arguments to the same call. + * @param other + * the node to compare + * @return + * true if these nodes are arguments to the same call, false if otherwise. + */ + def isArgToSameCallWith(other: Expression): Boolean = + node.astParent.start.collectAll[Call].headOption.equals( + other.astParent.start.collectAll[Call].headOption + ) + + /** Determines if this node has a flow to the given target node in the defined semantics. + * @param tgt + * the target node to check. + * @param semantics + * the pre-defined flow semantics. + * @return + * true if there is flow defined between the two nodes, false if otherwise. + */ + def hasDefinedFlowTo(tgt: Expression)(implicit semantics: Semantics): Boolean = + if tgt.evalType(".*(?i)(int|float|double|boolean).*").nonEmpty then + false + else val s = semanticsForCallByArg.l s.isEmpty || s.exists { semantic => semantic.mappings.exists { - case FlowMapping(_, ParameterNode(_, Some(dstName))) + case FlowMapping( + ParameterNode(_, Some(srcName)), + ParameterNode(_, Some(dstName)) + ) + if node.argumentName.isDefined && tgt.argumentName.isDefined => + srcName == node.argumentName.get && dstName == tgt.argumentName.get + case FlowMapping(ParameterNode(_, Some(srcName)), ParameterNode(dstIndex, _)) if node.argumentName.isDefined => - dstName == node.argumentName.get - case FlowMapping(_, ParameterNode(dstIndex, _)) => dstIndex == node.argumentIndex - case PassThroughMapping if node.argumentIndex != 0 => true - case _ => false + srcName == node.argumentName.get && dstIndex == tgt.argumentIndex + case FlowMapping(ParameterNode(srcIndex, _), ParameterNode(_, Some(dstName))) + if tgt.argumentName.isDefined => + srcIndex == node.argumentIndex && dstName == tgt.argumentName.get + case FlowMapping(ParameterNode(srcIndex, _), ParameterNode(dstIndex, _)) => + srcIndex == node.argumentIndex && dstIndex == tgt.argumentIndex + case PassThroughMapping + if tgt.argumentIndex == node.argumentIndex || tgt.argumentIndex == -1 => + true + case _ => false } } - /** Determines if this node and the given target node are arguments to the same call. - * @param other - * the node to compare - * @return - * true if these nodes are arguments to the same call, false if otherwise. - */ - def isArgToSameCallWith(other: Expression): Boolean = - node.astParent.start.collectAll[Call].headOption.equals( - other.astParent.start.collectAll[Call].headOption - ) - - /** Determines if this node has a flow to the given target node in the defined semantics. - * @param tgt - * the target node to check. - * @param semantics - * the pre-defined flow semantics. - * @return - * true if there is flow defined between the two nodes, false if otherwise. - */ - def hasDefinedFlowTo(tgt: Expression)(implicit semantics: Semantics): Boolean = - if tgt.evalType(".*(?i)(int|float|double|boolean).*").nonEmpty then - false - else - val s = semanticsForCallByArg.l - s.isEmpty || s.exists { semantic => - semantic.mappings.exists { - case FlowMapping( - ParameterNode(_, Some(srcName)), - ParameterNode(_, Some(dstName)) - ) - if node.argumentName.isDefined && tgt.argumentName.isDefined => - srcName == node.argumentName.get && dstName == tgt.argumentName.get - case FlowMapping(ParameterNode(_, Some(srcName)), ParameterNode(dstIndex, _)) - if node.argumentName.isDefined => - srcName == node.argumentName.get && dstIndex == tgt.argumentIndex - case FlowMapping(ParameterNode(srcIndex, _), ParameterNode(_, Some(dstName))) - if tgt.argumentName.isDefined => - srcIndex == node.argumentIndex && dstName == tgt.argumentName.get - case FlowMapping(ParameterNode(srcIndex, _), ParameterNode(dstIndex, _)) => - srcIndex == node.argumentIndex && dstIndex == tgt.argumentIndex - case PassThroughMapping - if tgt.argumentIndex == node.argumentIndex || tgt.argumentIndex == -1 => - true - case _ => false - } - } - - /** Retrieve flow semantic for the call this argument is a part of. - */ - def semanticsForCallByArg(implicit semantics: Semantics): Iterator[FlowSemantic] = - argToMethods(node).flatMap { method => - semantics.forMethod(method.fullName) - } + /** Retrieve flow semantic for the call this argument is a part of. + */ + def semanticsForCallByArg(implicit semantics: Semantics): Iterator[FlowSemantic] = + argToMethods(node).flatMap { method => + semantics.forMethod(method.fullName) + } - private def argToMethods(arg: Expression): Iterator[Method] = - arg.inCall.flatMap { call => - if call.nonEmpty then NoResolve.getCalledMethods(call) - else mutable.ArrayBuffer.empty[Method] - } + private def argToMethods(arg: Expression): Iterator[Method] = + arg.inCall.flatMap { call => + if call.nonEmpty then NoResolve.getCalledMethods(call) + else mutable.ArrayBuffer.empty[Method] + } end ExpressionMethods diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/language/nodemethods/ExtendedCfgNodeMethods.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/language/nodemethods/ExtendedCfgNodeMethods.scala index e4d87428..a3946a12 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/language/nodemethods/ExtendedCfgNodeMethods.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/language/nodemethods/ExtendedCfgNodeMethods.scala @@ -13,84 +13,84 @@ import scala.jdk.CollectionConverters.* class ExtendedCfgNodeMethods[NodeType <: CfgNode](val node: NodeType) extends AnyVal: - /** Convert to nearest AST node - */ - def astNode: AstNode = node + /** Convert to nearest AST node + */ + def astNode: AstNode = node - def reachableBy[NodeType]( - sourceTrav: Iterator[NodeType], - sourceTravs: IterableOnce[NodeType]* - )(implicit context: EngineContext): Iterator[NodeType] = - node.start.reachableBy(sourceTrav, sourceTravs*) + def reachableBy[NodeType]( + sourceTrav: Iterator[NodeType], + sourceTravs: IterableOnce[NodeType]* + )(implicit context: EngineContext): Iterator[NodeType] = + node.start.reachableBy(sourceTrav, sourceTravs*) - def ddgIn(implicit semantics: Semantics = DefaultSemantics()): Iterator[CfgNode] = - val cache = mutable.HashMap[CfgNode, Vector[PathElement]]() - val result = ddgIn(Vector(PathElement(node)), withInvisible = false, cache) - cache.clear() - result + def ddgIn(implicit semantics: Semantics = DefaultSemantics()): Iterator[CfgNode] = + val cache = mutable.HashMap[CfgNode, Vector[PathElement]]() + val result = ddgIn(Vector(PathElement(node)), withInvisible = false, cache) + cache.clear() + result - def ddgInPathElem( - withInvisible: Boolean, - cache: mutable.HashMap[CfgNode, Vector[PathElement]] = - mutable.HashMap[CfgNode, Vector[PathElement]]() - )(implicit semantics: Semantics): Iterator[PathElement] = - ddgInPathElem(Vector(PathElement(node)), withInvisible, cache) + def ddgInPathElem( + withInvisible: Boolean, + cache: mutable.HashMap[CfgNode, Vector[PathElement]] = + mutable.HashMap[CfgNode, Vector[PathElement]]() + )(implicit semantics: Semantics): Iterator[PathElement] = + ddgInPathElem(Vector(PathElement(node)), withInvisible, cache) - def ddgInPathElem(implicit semantics: Semantics): Iterator[PathElement] = - val cache = mutable.HashMap[CfgNode, Vector[PathElement]]() - val result = ddgInPathElem(Vector(PathElement(node)), withInvisible = false, cache) - cache.clear() - result + def ddgInPathElem(implicit semantics: Semantics): Iterator[PathElement] = + val cache = mutable.HashMap[CfgNode, Vector[PathElement]]() + val result = ddgInPathElem(Vector(PathElement(node)), withInvisible = false, cache) + cache.clear() + result - /** Traverse back in the data dependence graph by one step, taking into account semantics - * @param path - * optional list of path elements that have been expanded already - */ - def ddgIn( - path: Vector[PathElement], - withInvisible: Boolean, - cache: mutable.HashMap[CfgNode, Vector[PathElement]] - )( - implicit semantics: Semantics - ): Iterator[CfgNode] = - ddgInPathElem(path, withInvisible, cache).map(_.node.asInstanceOf[CfgNode]) + /** Traverse back in the data dependence graph by one step, taking into account semantics + * @param path + * optional list of path elements that have been expanded already + */ + def ddgIn( + path: Vector[PathElement], + withInvisible: Boolean, + cache: mutable.HashMap[CfgNode, Vector[PathElement]] + )( + implicit semantics: Semantics + ): Iterator[CfgNode] = + ddgInPathElem(path, withInvisible, cache).map(_.node.asInstanceOf[CfgNode]) - /** Traverse back in the data dependence graph by one step and generate corresponding - * PathElement, taking into account semantics - * @param path - * optional list of path elements that have been expanded already - */ - def ddgInPathElem( - path: Vector[PathElement], - withInvisible: Boolean, - cache: mutable.HashMap[CfgNode, Vector[PathElement]] - )(implicit semantics: Semantics): Iterator[PathElement] = - val result = ddgInPathElemInternal(path, withInvisible, cache).iterator - result + /** Traverse back in the data dependence graph by one step and generate corresponding PathElement, + * taking into account semantics + * @param path + * optional list of path elements that have been expanded already + */ + def ddgInPathElem( + path: Vector[PathElement], + withInvisible: Boolean, + cache: mutable.HashMap[CfgNode, Vector[PathElement]] + )(implicit semantics: Semantics): Iterator[PathElement] = + val result = ddgInPathElemInternal(path, withInvisible, cache).iterator + result - private def ddgInPathElemInternal( - path: Vector[PathElement], - withInvisible: Boolean, - cache: mutable.HashMap[CfgNode, Vector[PathElement]] - )(implicit semantics: Semantics): Vector[PathElement] = + private def ddgInPathElemInternal( + path: Vector[PathElement], + withInvisible: Boolean, + cache: mutable.HashMap[CfgNode, Vector[PathElement]] + )(implicit semantics: Semantics): Vector[PathElement] = - if cache.contains(node) then - return cache(node) + if cache.contains(node) then + return cache(node) - val elems = Engine.expandIn(node, path) - val result = if withInvisible then - elems - else - (elems.filter(_.visible) ++ elems - .filterNot(_.visible) - .flatMap(x => - x.node.asInstanceOf[CfgNode].ddgInPathElem( - x +: path, - withInvisible = false, - cache - ) - )).distinct - cache.put(node, result) - result - end ddgInPathElemInternal + val elems = Engine.expandIn(node, path) + val result = if withInvisible then + elems + else + (elems.filter(_.visible) ++ elems + .filterNot(_.visible) + .flatMap(x => + x.node.asInstanceOf[CfgNode].ddgInPathElem( + x +: path, + withInvisible = false, + cache + ) + )).distinct + cache.put(node, result) + result + end ddgInPathElemInternal end ExtendedCfgNodeMethods diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/language/package.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/language/package.scala index c5bac851..b79ef815 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/language/package.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/language/package.scala @@ -11,34 +11,33 @@ import scala.language.implicitConversions package object language: - implicit def cfgNodeToMethodsQp[NodeType <: CfgNode](node: NodeType) - : ExtendedCfgNodeMethods[NodeType] = - new ExtendedCfgNodeMethods(node) + implicit def cfgNodeToMethodsQp[NodeType <: CfgNode](node: NodeType) + : ExtendedCfgNodeMethods[NodeType] = + new ExtendedCfgNodeMethods(node) - implicit def expressionMethods[NodeType <: Expression](node: NodeType) - : ExpressionMethods[NodeType] = - new ExpressionMethods(node) + implicit def expressionMethods[NodeType <: Expression](node: NodeType) + : ExpressionMethods[NodeType] = + new ExpressionMethods(node) - implicit def toExtendedCfgNode[NodeType <: CfgNode](traversal: IterableOnce[NodeType]) - : ExtendedCfgNode = - new ExtendedCfgNode(traversal.iterator) + implicit def toExtendedCfgNode[NodeType <: CfgNode](traversal: IterableOnce[NodeType]) + : ExtendedCfgNode = + new ExtendedCfgNode(traversal.iterator) - implicit def toDdgNodeDot(traversal: IterableOnce[Method]): DdgNodeDot = - new DdgNodeDot(traversal.iterator) + implicit def toDdgNodeDot(traversal: IterableOnce[Method]): DdgNodeDot = + new DdgNodeDot(traversal.iterator) - implicit def toDdgNodeDotSingle(method: Method): DdgNodeDot = - new DdgNodeDot(Iterator.single(method)) + implicit def toDdgNodeDotSingle(method: Method): DdgNodeDot = + new DdgNodeDot(Iterator.single(method)) - implicit def toExtendedPathsTrav[NodeType <: Path](traversal: IterableOnce[NodeType]) - : PassesExt = - new PassesExt(traversal.iterator) + implicit def toExtendedPathsTrav[NodeType <: Path](traversal: IterableOnce[NodeType]): PassesExt = + new PassesExt(traversal.iterator) - class PassesExt(traversal: Iterator[Path]): + class PassesExt(traversal: Iterator[Path]): - def passes(trav: Iterator[AstNode] => Iterator[?]): Iterator[Path] = - traversal.filter(_.elements.exists(_.start.where(trav).nonEmpty)) + def passes(trav: Iterator[AstNode] => Iterator[?]): Iterator[Path] = + traversal.filter(_.elements.exists(_.start.where(trav).nonEmpty)) - def passesNot(trav: Iterator[AstNode] => Iterator[?]): Iterator[Path] = - traversal.filter(_.elements.forall(_.start.where(trav).isEmpty)) + def passesNot(trav: Iterator[AstNode] => Iterator[?]): Iterator[Path] = + traversal.filter(_.elements.forall(_.start.where(trav).isEmpty)) end language diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/layers/dataflows/DumpCpg14.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/layers/dataflows/DumpCpg14.scala index a167d36b..6b0b2092 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/layers/dataflows/DumpCpg14.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/layers/dataflows/DumpCpg14.scala @@ -11,21 +11,21 @@ case class Cpg14DumpOptions(var outDir: String) extends LayerCreatorOptions {} object DumpCpg14: - val overlayName = "dumpCpg14" + val overlayName = "dumpCpg14" - val description = "Dump Code Property Graph (2014) to out/" + val description = "Dump Code Property Graph (2014) to out/" - def defaultOpts: Cpg14DumpOptions = Cpg14DumpOptions("out") + def defaultOpts: Cpg14DumpOptions = Cpg14DumpOptions("out") class DumpCpg14(options: Cpg14DumpOptions)(implicit semantics: Semantics = DefaultSemantics()) extends LayerCreator: - override val overlayName: String = DumpDdg.overlayName - override val description: String = DumpDdg.description - override val storeOverlayName: Boolean = false - - override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = - val cpg = context.cpg - cpg.method.zipWithIndex.foreach { case (method, i) => - val str = method.dotCpg14.head - (File(options.outDir) / s"$i-cpg.dot").write(str) - } + override val overlayName: String = DumpDdg.overlayName + override val description: String = DumpDdg.description + override val storeOverlayName: Boolean = false + + override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = + val cpg = context.cpg + cpg.method.zipWithIndex.foreach { case (method, i) => + val str = method.dotCpg14.head + (File(options.outDir) / s"$i-cpg.dot").write(str) + } diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/layers/dataflows/DumpDdg.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/layers/dataflows/DumpDdg.scala index 2da5800a..3cfde8d5 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/layers/dataflows/DumpDdg.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/layers/dataflows/DumpDdg.scala @@ -11,21 +11,21 @@ case class DdgDumpOptions(var outDir: String) extends LayerCreatorOptions {} object DumpDdg: - val overlayName = "dumpDdg" + val overlayName = "dumpDdg" - val description = "Dump data dependence graphs to out/" + val description = "Dump data dependence graphs to out/" - def defaultOpts: DdgDumpOptions = DdgDumpOptions("out") + def defaultOpts: DdgDumpOptions = DdgDumpOptions("out") class DumpDdg(options: DdgDumpOptions)(implicit semantics: Semantics = DefaultSemantics()) extends LayerCreator: - override val overlayName: String = DumpDdg.overlayName - override val description: String = DumpDdg.description - override val storeOverlayName: Boolean = false - - override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = - val cpg = context.cpg - cpg.method.zipWithIndex.foreach { case (method, i) => - val str = method.dotDdg.head - (File(options.outDir) / s"$i-ddg.dot").write(str) - } + override val overlayName: String = DumpDdg.overlayName + override val description: String = DumpDdg.description + override val storeOverlayName: Boolean = false + + override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = + val cpg = context.cpg + cpg.method.zipWithIndex.foreach { case (method, i) => + val str = method.dotDdg.head + (File(options.outDir) / s"$i-ddg.dot").write(str) + } diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/layers/dataflows/DumpPdg.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/layers/dataflows/DumpPdg.scala index 7b8f3aed..5ec3dd67 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/layers/dataflows/DumpPdg.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/layers/dataflows/DumpPdg.scala @@ -11,21 +11,21 @@ case class PdgDumpOptions(var outDir: String) extends LayerCreatorOptions {} object DumpPdg: - val overlayName = "dumpPdg" + val overlayName = "dumpPdg" - val description = "Dump program dependence graph to out/" + val description = "Dump program dependence graph to out/" - def defaultOpts: PdgDumpOptions = PdgDumpOptions("out") + def defaultOpts: PdgDumpOptions = PdgDumpOptions("out") class DumpPdg(options: PdgDumpOptions)(implicit semantics: Semantics = DefaultSemantics()) extends LayerCreator: - override val overlayName: String = DumpPdg.overlayName - override val description: String = DumpPdg.description - override val storeOverlayName: Boolean = false - - override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = - val cpg = context.cpg - cpg.method.zipWithIndex.foreach { case (method, i) => - val str = method.dotPdg.head - (File(options.outDir) / s"$i-pdg.dot").write(str) - } + override val overlayName: String = DumpPdg.overlayName + override val description: String = DumpPdg.description + override val storeOverlayName: Boolean = false + + override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = + val cpg = context.cpg + cpg.method.zipWithIndex.foreach { case (method, i) => + val str = method.dotPdg.head + (File(options.outDir) / s"$i-pdg.dot").write(str) + } diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/layers/dataflows/OssDataFlow.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/layers/dataflows/OssDataFlow.scala index 5a5db4be..001fc12e 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/layers/dataflows/OssDataFlow.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/layers/dataflows/OssDataFlow.scala @@ -6,10 +6,10 @@ import io.appthreat.dataflowengineoss.semanticsloader.{FlowSemantic, Semantics} import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext, LayerCreatorOptions} object OssDataFlow: - val overlayName: String = "dataflowOss" - val description: String = "Layer to support the OSS lightweight data flow tracker" + val overlayName: String = "dataflowOss" + val description: String = "Layer to support the OSS lightweight data flow tracker" - def defaultOpts = new OssDataFlowOptions() + def defaultOpts = new OssDataFlowOptions() class OssDataFlowOptions( var maxNumberOfDefinitions: Int = 4000, @@ -20,12 +20,12 @@ class OssDataFlow(opts: OssDataFlowOptions)(implicit s: Semantics = Semantics.fromList(DefaultSemantics().elements ++ opts.extraFlows) ) extends LayerCreator: - override val overlayName: String = OssDataFlow.overlayName - override val description: String = OssDataFlow.description + override val overlayName: String = OssDataFlow.overlayName + override val description: String = OssDataFlow.description - override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = - val cpg = context.cpg - val enhancementExecList = Iterator(new ReachingDefPass(cpg, opts.maxNumberOfDefinitions)) - enhancementExecList.zipWithIndex.foreach { case (pass, index) => - runPass(pass, context, storeUndoInfo, index) - } + override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = + val cpg = context.cpg + val enhancementExecList = Iterator(new ReachingDefPass(cpg, opts.maxNumberOfDefinitions)) + enhancementExecList.zipWithIndex.foreach { case (pass, index) => + runPass(pass, context, storeUndoInfo, index) + } diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/package.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/package.scala index aeb69881..349c7624 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/package.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/package.scala @@ -5,16 +5,16 @@ import io.shiftleft.semanticcpg.language.* package object dataflowengineoss: - def globalFromLiteral(lit: Literal): Iterator[Expression] = lit.start - .where(_.inAssignment.method.nameExact("", ":package")) - .inAssignment - .argument(1) + def globalFromLiteral(lit: Literal): Iterator[Expression] = lit.start + .where(_.inAssignment.method.nameExact("", ":package")) + .inAssignment + .argument(1) - def identifierToFirstUsages(node: Identifier): List[Identifier] = - node.refsTo.flatMap(identifiersFromCapturedScopes).l + def identifierToFirstUsages(node: Identifier): List[Identifier] = + node.refsTo.flatMap(identifiersFromCapturedScopes).l - def identifiersFromCapturedScopes(i: Declaration): List[Identifier] = - i.capturedByMethodRef.referencedMethod.ast.isIdentifier - .nameExact(i.name) - .sortBy(x => (x.lineNumber, x.columnNumber)) - .l + def identifiersFromCapturedScopes(i: Declaration): List[Identifier] = + i.capturedByMethodRef.referencedMethod.ast.isIdentifier + .nameExact(i.name) + .sortBy(x => (x.lineNumber, x.columnNumber)) + .l diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/DataFlowProblem.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/DataFlowProblem.scala index 207bf662..f1d69465 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/DataFlowProblem.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/DataFlowProblem.scala @@ -20,17 +20,17 @@ class DataFlowProblem[Node, V]( * discrepancies. */ trait FlowGraph[Node]: - val allNodesReversePostOrder: List[Node] - val allNodesPostOrder: List[Node] - def succ(node: Node): IterableOnce[Node] - def pred(node: Node): IterableOnce[Node] + val allNodesReversePostOrder: List[Node] + val allNodesPostOrder: List[Node] + def succ(node: Node): IterableOnce[Node] + def pred(node: Node): IterableOnce[Node] /** This is actually a function family consisting of one transfer function for each node of the flow * graph. Each function maps from the analysis domain to the analysis domain, e.g., for reaching * definitions, sets of definitions are mapped to sets of definitions. */ trait TransferFunction[Node, V]: - def apply(n: Node, x: V): V + def apply(n: Node, x: V): V /** As a practical optimization, OUT[N] is often initialized to GEN[N]. Moreover, we need a way of * specifying boundary conditions such as OUT[ENTRY] = {}. We achieve both by allowing the data @@ -38,9 +38,9 @@ trait TransferFunction[Node, V]: */ trait InOutInit[Node, V]: - def initIn: Map[Node, V] + def initIn: Map[Node, V] - def initOut: Map[Node, V] + def initOut: Map[Node, V] /** The solution consists of `in` and `out` for each node of the flow graph. We also attach the * problem. diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/DataFlowSolver.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/DataFlowSolver.scala index eb656b0e..d2708d6b 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/DataFlowSolver.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/DataFlowSolver.scala @@ -4,73 +4,73 @@ import scala.collection.mutable class DataFlowSolver: - /** Calculate fix point solution via a standard work list algorithm (Forwards). The result is - * given by two maps: `in` and `out`. These maps associate all CFG nodes with the set of - * definitions at node entry and node exit respectively. - */ - def calculateMopSolutionForwards[Node, T <: Iterable[?]](problem: DataFlowProblem[Node, T]) - : Solution[Node, T] = - var out: Map[Node, T] = problem.inOutInit.initOut - var in = problem.inOutInit.initIn - val workList = mutable.ListBuffer[Node]() - workList ++= problem.flowGraph.allNodesReversePostOrder + /** Calculate fix point solution via a standard work list algorithm (Forwards). The result is + * given by two maps: `in` and `out`. These maps associate all CFG nodes with the set of + * definitions at node entry and node exit respectively. + */ + def calculateMopSolutionForwards[Node, T <: Iterable[?]](problem: DataFlowProblem[Node, T]) + : Solution[Node, T] = + var out: Map[Node, T] = problem.inOutInit.initOut + var in = problem.inOutInit.initIn + val workList = mutable.ListBuffer[Node]() + workList ++= problem.flowGraph.allNodesReversePostOrder - while workList.nonEmpty do - val newEntries = workList.flatMap { n => - val inSet = problem.flowGraph - .pred(n) - .iterator - .map(out) - .reduceOption((x, y) => problem.meet(x, y)) - .getOrElse(problem.empty) - in += n -> inSet - val old = out(n) - val newSet = problem.transferFunction(n, inSet) - val changed = !old.equals(newSet) - out += n -> newSet - if changed then - problem.flowGraph.succ(n) - else - List() - } - workList.clear() - workList ++= newEntries.distinct - end while - Solution(in, out, problem) - end calculateMopSolutionForwards + while workList.nonEmpty do + val newEntries = workList.flatMap { n => + val inSet = problem.flowGraph + .pred(n) + .iterator + .map(out) + .reduceOption((x, y) => problem.meet(x, y)) + .getOrElse(problem.empty) + in += n -> inSet + val old = out(n) + val newSet = problem.transferFunction(n, inSet) + val changed = !old.equals(newSet) + out += n -> newSet + if changed then + problem.flowGraph.succ(n) + else + List() + } + workList.clear() + workList ++= newEntries.distinct + end while + Solution(in, out, problem) + end calculateMopSolutionForwards - /** Calculate fix point solution via a standard work list algorithm (Backwards). The result is - * given by two maps: `in` and `out`. These maps associate all CFG nodes with the set of - * definitions at node entry and node exit respectively. - */ - def calculateMopSolutionBackwards[Node, T <: Iterable[?]](problem: DataFlowProblem[Node, T]) - : Solution[Node, T] = - var out: Map[Node, T] = problem.inOutInit.initOut - var in = problem.inOutInit.initIn - val workList = mutable.ListBuffer[Node]() - workList ++= problem.flowGraph.allNodesPostOrder + /** Calculate fix point solution via a standard work list algorithm (Backwards). The result is + * given by two maps: `in` and `out`. These maps associate all CFG nodes with the set of + * definitions at node entry and node exit respectively. + */ + def calculateMopSolutionBackwards[Node, T <: Iterable[?]](problem: DataFlowProblem[Node, T]) + : Solution[Node, T] = + var out: Map[Node, T] = problem.inOutInit.initOut + var in = problem.inOutInit.initIn + val workList = mutable.ListBuffer[Node]() + workList ++= problem.flowGraph.allNodesPostOrder - while workList.nonEmpty do - val newEntries = workList.flatMap { n => - val outSet = problem.flowGraph - .succ(n) - .iterator - .map(in) - .reduceOption((x, y) => problem.meet(x, y)) - .getOrElse(problem.empty) - out += n -> outSet - val old = in(n) - val newSet = problem.transferFunction(n, outSet) - val changed = !old.equals(newSet) - in += n -> newSet - if changed then - problem.flowGraph.pred(n) - else - List() - } - workList.clear() - workList ++= newEntries.distinct - end while - Solution(in, out, problem) - end calculateMopSolutionBackwards + while workList.nonEmpty do + val newEntries = workList.flatMap { n => + val outSet = problem.flowGraph + .succ(n) + .iterator + .map(in) + .reduceOption((x, y) => problem.meet(x, y)) + .getOrElse(problem.empty) + out += n -> outSet + val old = in(n) + val newSet = problem.transferFunction(n, outSet) + val changed = !old.equals(newSet) + in += n -> newSet + if changed then + problem.flowGraph.pred(n) + else + List() + } + workList.clear() + workList ++= newEntries.distinct + end while + Solution(in, out, problem) + end calculateMopSolutionBackwards end DataFlowSolver diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/DdgGenerator.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/DdgGenerator.scala index 081994a4..cc74619a 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/DdgGenerator.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/DdgGenerator.scala @@ -16,242 +16,241 @@ import scala.collection.{Set, mutable} */ class DdgGenerator(semantics: Semantics): - implicit val s: Semantics = semantics - - /** Once reaching definitions have been computed, we create a data dependence graph by checking - * which reaching definitions are relevant, meaning that a symbol is propagated that is used by - * the target node. - * - * @param dstGraph - * the diff graph to add edges to - * @param problem - * the reaching definition problem - * @param solution - * the solution to `problem` + implicit val s: Semantics = semantics + + /** Once reaching definitions have been computed, we create a data dependence graph by checking + * which reaching definitions are relevant, meaning that a symbol is propagated that is used by + * the target node. + * + * @param dstGraph + * the diff graph to add edges to + * @param problem + * the reaching definition problem + * @param solution + * the solution to `problem` + */ + def addReachingDefEdges( + dstGraph: DiffGraphBuilder, + method: Method, + problem: DataFlowProblem[StoredNode, mutable.BitSet], + solution: Solution[StoredNode, mutable.BitSet] + ): Unit = + implicit val implicitDst: DiffGraphBuilder = dstGraph + + val numberToNode = problem.flowGraph.asInstanceOf[ReachingDefFlowGraph].numberToNode + val in = solution.in + val gen = solution.problem.transferFunction.asInstanceOf[ReachingDefTransferFunction].gen + + val allNodes = in.keys.toList + val usageAnalyzer = new UsageAnalyzer(problem, in) + + /** Add an edge from the entry node to each node that does not have other incoming definitions. */ - def addReachingDefEdges( - dstGraph: DiffGraphBuilder, - method: Method, - problem: DataFlowProblem[StoredNode, mutable.BitSet], - solution: Solution[StoredNode, mutable.BitSet] - ): Unit = - implicit val implicitDst: DiffGraphBuilder = dstGraph - - val numberToNode = problem.flowGraph.asInstanceOf[ReachingDefFlowGraph].numberToNode - val in = solution.in - val gen = solution.problem.transferFunction.asInstanceOf[ReachingDefTransferFunction].gen - - val allNodes = in.keys.toList - val usageAnalyzer = new UsageAnalyzer(problem, in) - - /** Add an edge from the entry node to each node that does not have other incoming - * definitions. - */ - def addEdgesFromEntryNode(): Unit = - // Add edges from the entry node - allNodes - .filter(n => isDdgNode(n) && usageAnalyzer.usedIncomingDefs(n).isEmpty) - .foreach { node => - addEdge(method, node) - } - - // This handles `foo(new Bar()) or return new Bar()` - def addEdgeForBlock(block: Block, towards: StoredNode): Unit = - block.astChildren.lastOption match - case None => // Do nothing - case Some(node: Identifier) => - val edgesToAdd = in(node).toList - .flatMap(numberToNode.get) - .filter(inDef => usageAnalyzer.isUsing(node, inDef)) - .collect { - case identifier: Identifier => identifier - case call: Call => call - } - edgesToAdd.foreach { inNode => - addEdge(inNode, block, nodeToEdgeLabel(inNode)) - } - if edgesToAdd.nonEmpty then - addEdge(block, towards) - case Some(node: Call) => - addEdge(node, block, nodeToEdgeLabel(node)) - addEdge(block, towards) - case _ => // Do nothing - /** Adds incoming edges to arguments of call sites, including edges between arguments of the - * same call site. - */ - def addEdgesToCallSite(call: Call): Unit = - // Edges between arguments of call sites - usageAnalyzer.usedIncomingDefs(call).foreach { case (use, ins) => - ins.foreach { in => - val inNode = numberToNode(in) - if inNode != use then - addEdge(inNode, use, nodeToEdgeLabel(inNode)) - } - } - - // For all calls, assume that input arguments - // taint corresponding output arguments - // and the return value. We filter invalid - // edges at query time (according to the given semantic). - usageAnalyzer.uses(call).foreach { use => - gen(call).foreach { g => - val genNode = numberToNode(g) - if use != genNode && isDdgNode(use) then - addEdge(use, genNode, nodeToEdgeLabel(use)) - } - } - - // This handles `foo(new Bar())`, which is lowered to - // `foo({Bar tmp = Bar.alloc(); tmp.init(); tmp})` - call.argument.isBlock.foreach { block => addEdgeForBlock(block, call) } - end addEdgesToCallSite - - def addEdgesToReturn(ret: Return): Unit = - // This handles `return new Bar()`, which is lowered to - // `return {Bar tmp = Bar.alloc(); tmp.init(); tmp}` - usageAnalyzer.uses(ret).collectAll[Block].foreach(block => addEdgeForBlock(block, ret)) - usageAnalyzer.usedIncomingDefs(ret).foreach { case (use: CfgNode, inElements) => - addEdge(use, ret, use.code) - inElements - .filterNot(x => numberToNode.get(x).contains(use)) - .flatMap(numberToNode.get) - .foreach { inElemNode => - addEdge(inElemNode, use, nodeToEdgeLabel(inElemNode)) - } - if inElements.isEmpty then - addEdge(method, ret) - } - addEdge(ret, method.methodReturn, "") - - def addEdgesToMethodParameterOut(paramOut: MethodParameterOut): Unit = - // There is always an edge from the method input parameter - // to the corresponding method output parameter as modifications - // of the input parameter only affect a copy. - paramOut.paramIn.foreach { paramIn => - addEdge(paramIn, paramOut, paramIn.name) - } - usageAnalyzer.usedIncomingDefs(paramOut).foreach { case (_, inElements) => - inElements.foreach { inElement => - val inElemNode = numberToNode(inElement) - val edgeLabel = nodeToEdgeLabel(inElemNode) - addEdge(inElemNode, paramOut, edgeLabel) - } - } - - def addEdgesToExitNode(exitNode: MethodReturn): Unit = - in(exitNode).foreach { i => - val iNode = numberToNode(i) - addEdge(iNode, exitNode, nodeToEdgeLabel(iNode)) - } - - /** This is part of the Lone-identifier optimization: as we remove lone identifiers from - * `gen` sets, we must now retrieve them and create an edge from each lone identifier to - * the exit node. - */ - def addEdgesFromLoneIdentifiersToExit(method: Method): Unit = - val numberToNode = problem.flowGraph.asInstanceOf[ReachingDefFlowGraph].numberToNode - val exitNode = method.methodReturn - val transferFunction = - solution.problem.transferFunction.asInstanceOf[OptimizedReachingDefTransferFunction] - val genOnce = transferFunction.loneIdentifiers - genOnce.foreach { case (_, defs) => - defs.foreach { d => - val dNode = numberToNode(d) - addEdge(dNode, exitNode, nodeToEdgeLabel(dNode)) - } - } - - def addEdgesToCapturedIdentifiersAndParameters(): Unit = - val identifierDestPairs = - method._identifierViaContainsOut.flatMap { identifier => - val firstAndLastUsageByMethod = - identifierToFirstUsages(identifier).groupBy(_.method) - firstAndLastUsageByMethod.values - .filter(_.nonEmpty) - .map(x => (x.head, x.last)) - .flatMap { case (firstUsage, lastUsage) => - ( - identifier.lineNumber, - firstUsage.lineNumber, - lastUsage.lineNumber - ) match - case (Some(iNo), Some(fNo), _) if iNo <= fNo => - Some(identifier, firstUsage) - case (Some(iNo), _, Some(lNo)) if iNo >= lNo => - Some(lastUsage, identifier) - case _ => None - } - }.distinct - - identifierDestPairs - .foreach { case (src, dst) => - addEdge(src, dst, nodeToEdgeLabel(src)) - } - method.parameter.foreach { param => - param.capturedByMethodRef.referencedMethod.ast.isIdentifier.foreach { identifier => - addEdge(param, identifier, nodeToEdgeLabel(param)) - } + def addEdgesFromEntryNode(): Unit = + // Add edges from the entry node + allNodes + .filter(n => isDdgNode(n) && usageAnalyzer.usedIncomingDefs(n).isEmpty) + .foreach { node => + addEdge(method, node) } - val globalIdentifiers = - method.ast.isLiteral.flatMap(globalFromLiteral).collectAll[Identifier].l - globalIdentifiers - .foreach { global => - identifierToFirstUsages(global).map { identifier => - addEdge(global, identifier, nodeToEdgeLabel(global)) - } - } - end addEdgesToCapturedIdentifiersAndParameters - - addEdgesFromEntryNode() - allNodes.foreach { - case call: Call => addEdgesToCallSite(call) - case ret: Return => addEdgesToReturn(ret) - case paramOut: MethodParameterOut => addEdgesToMethodParameterOut(paramOut) - case _ => + // This handles `foo(new Bar()) or return new Bar()` + def addEdgeForBlock(block: Block, towards: StoredNode): Unit = + block.astChildren.lastOption match + case None => // Do nothing + case Some(node: Identifier) => + val edgesToAdd = in(node).toList + .flatMap(numberToNode.get) + .filter(inDef => usageAnalyzer.isUsing(node, inDef)) + .collect { + case identifier: Identifier => identifier + case call: Call => call + } + edgesToAdd.foreach { inNode => + addEdge(inNode, block, nodeToEdgeLabel(inNode)) + } + if edgesToAdd.nonEmpty then + addEdge(block, towards) + case Some(node: Call) => + addEdge(node, block, nodeToEdgeLabel(node)) + addEdge(block, towards) + case _ => // Do nothing + /** Adds incoming edges to arguments of call sites, including edges between arguments of the + * same call site. + */ + def addEdgesToCallSite(call: Call): Unit = + // Edges between arguments of call sites + usageAnalyzer.usedIncomingDefs(call).foreach { case (use, ins) => + ins.foreach { in => + val inNode = numberToNode(in) + if inNode != use then + addEdge(inNode, use, nodeToEdgeLabel(inNode)) + } + } + + // For all calls, assume that input arguments + // taint corresponding output arguments + // and the return value. We filter invalid + // edges at query time (according to the given semantic). + usageAnalyzer.uses(call).foreach { use => + gen(call).foreach { g => + val genNode = numberToNode(g) + if use != genNode && isDdgNode(use) then + addEdge(use, genNode, nodeToEdgeLabel(use)) + } + } + + // This handles `foo(new Bar())`, which is lowered to + // `foo({Bar tmp = Bar.alloc(); tmp.init(); tmp})` + call.argument.isBlock.foreach { block => addEdgeForBlock(block, call) } + end addEdgesToCallSite + + def addEdgesToReturn(ret: Return): Unit = + // This handles `return new Bar()`, which is lowered to + // `return {Bar tmp = Bar.alloc(); tmp.init(); tmp}` + usageAnalyzer.uses(ret).collectAll[Block].foreach(block => addEdgeForBlock(block, ret)) + usageAnalyzer.usedIncomingDefs(ret).foreach { case (use: CfgNode, inElements) => + addEdge(use, ret, use.code) + inElements + .filterNot(x => numberToNode.get(x).contains(use)) + .flatMap(numberToNode.get) + .foreach { inElemNode => + addEdge(inElemNode, use, nodeToEdgeLabel(inElemNode)) + } + if inElements.isEmpty then + addEdge(method, ret) + } + addEdge(ret, method.methodReturn, "") + + def addEdgesToMethodParameterOut(paramOut: MethodParameterOut): Unit = + // There is always an edge from the method input parameter + // to the corresponding method output parameter as modifications + // of the input parameter only affect a copy. + paramOut.paramIn.foreach { paramIn => + addEdge(paramIn, paramOut, paramIn.name) + } + usageAnalyzer.usedIncomingDefs(paramOut).foreach { case (_, inElements) => + inElements.foreach { inElement => + val inElemNode = numberToNode(inElement) + val edgeLabel = nodeToEdgeLabel(inElemNode) + addEdge(inElemNode, paramOut, edgeLabel) + } + } + + def addEdgesToExitNode(exitNode: MethodReturn): Unit = + in(exitNode).foreach { i => + val iNode = numberToNode(i) + addEdge(iNode, exitNode, nodeToEdgeLabel(iNode)) } - addEdgesToCapturedIdentifiersAndParameters() - addEdgesToExitNode(method.methodReturn) - addEdgesFromLoneIdentifiersToExit(method) - end addReachingDefEdges - - private def addEdge(fromNode: StoredNode, toNode: StoredNode, variable: String = "")(implicit - dstGraph: DiffGraphBuilder - ): Unit = - if fromNode.isInstanceOf[Unknown] || toNode.isInstanceOf[Unknown] then - return - - (fromNode, toNode) match - case (parentNode: CfgNode, childNode: CfgNode) - if EdgeValidator.isValidEdge(childNode, parentNode) => - dstGraph.addEdge( - fromNode, - toNode, - EdgeTypes.REACHING_DEF, - PropertyNames.VARIABLE, - variable - ) - case _ => - - /** There are a few node types that (a) are not to be considered in the DDG, or (b) are not - * standalone DDG nodes, or (c) have a special meaning in the DDG. This function indicates - * whether the given node is just a regular DDG node instead. + /** This is part of the Lone-identifier optimization: as we remove lone identifiers from `gen` + * sets, we must now retrieve them and create an edge from each lone identifier to the exit + * node. */ - private def isDdgNode(x: StoredNode): Boolean = - x match - case _: Method => false - case _: ControlStructure => false - case _: FieldIdentifier => false - case _: JumpTarget => false - case _: MethodReturn => false - case _ => true - - private def nodeToEdgeLabel(node: StoredNode): String = - node match - case n: MethodParameterIn => n.name - case n: CfgNode => n.code - case _ => "" + def addEdgesFromLoneIdentifiersToExit(method: Method): Unit = + val numberToNode = problem.flowGraph.asInstanceOf[ReachingDefFlowGraph].numberToNode + val exitNode = method.methodReturn + val transferFunction = + solution.problem.transferFunction.asInstanceOf[OptimizedReachingDefTransferFunction] + val genOnce = transferFunction.loneIdentifiers + genOnce.foreach { case (_, defs) => + defs.foreach { d => + val dNode = numberToNode(d) + addEdge(dNode, exitNode, nodeToEdgeLabel(dNode)) + } + } + + def addEdgesToCapturedIdentifiersAndParameters(): Unit = + val identifierDestPairs = + method._identifierViaContainsOut.flatMap { identifier => + val firstAndLastUsageByMethod = + identifierToFirstUsages(identifier).groupBy(_.method) + firstAndLastUsageByMethod.values + .filter(_.nonEmpty) + .map(x => (x.head, x.last)) + .flatMap { case (firstUsage, lastUsage) => + ( + identifier.lineNumber, + firstUsage.lineNumber, + lastUsage.lineNumber + ) match + case (Some(iNo), Some(fNo), _) if iNo <= fNo => + Some(identifier, firstUsage) + case (Some(iNo), _, Some(lNo)) if iNo >= lNo => + Some(lastUsage, identifier) + case _ => None + } + }.distinct + + identifierDestPairs + .foreach { case (src, dst) => + addEdge(src, dst, nodeToEdgeLabel(src)) + } + method.parameter.foreach { param => + param.capturedByMethodRef.referencedMethod.ast.isIdentifier.foreach { identifier => + addEdge(param, identifier, nodeToEdgeLabel(param)) + } + } + + val globalIdentifiers = + method.ast.isLiteral.flatMap(globalFromLiteral).collectAll[Identifier].l + globalIdentifiers + .foreach { global => + identifierToFirstUsages(global).map { identifier => + addEdge(global, identifier, nodeToEdgeLabel(global)) + } + } + end addEdgesToCapturedIdentifiersAndParameters + + addEdgesFromEntryNode() + allNodes.foreach { + case call: Call => addEdgesToCallSite(call) + case ret: Return => addEdgesToReturn(ret) + case paramOut: MethodParameterOut => addEdgesToMethodParameterOut(paramOut) + case _ => + } + + addEdgesToCapturedIdentifiersAndParameters() + addEdgesToExitNode(method.methodReturn) + addEdgesFromLoneIdentifiersToExit(method) + end addReachingDefEdges + + private def addEdge(fromNode: StoredNode, toNode: StoredNode, variable: String = "")(implicit + dstGraph: DiffGraphBuilder + ): Unit = + if fromNode.isInstanceOf[Unknown] || toNode.isInstanceOf[Unknown] then + return + + (fromNode, toNode) match + case (parentNode: CfgNode, childNode: CfgNode) + if EdgeValidator.isValidEdge(childNode, parentNode) => + dstGraph.addEdge( + fromNode, + toNode, + EdgeTypes.REACHING_DEF, + PropertyNames.VARIABLE, + variable + ) + case _ => + + /** There are a few node types that (a) are not to be considered in the DDG, or (b) are not + * standalone DDG nodes, or (c) have a special meaning in the DDG. This function indicates + * whether the given node is just a regular DDG node instead. + */ + private def isDdgNode(x: StoredNode): Boolean = + x match + case _: Method => false + case _: ControlStructure => false + case _: FieldIdentifier => false + case _: JumpTarget => false + case _: MethodReturn => false + case _ => true + + private def nodeToEdgeLabel(node: StoredNode): String = + node match + case n: MethodParameterIn => n.name + case n: CfgNode => n.code + case _ => "" end DdgGenerator /** Upon calculating reaching definitions, we find ourselves with a set of incoming definitions @@ -263,111 +262,111 @@ private class UsageAnalyzer( in: Map[StoredNode, Set[Definition]] ): - val numberToNode: Map[Definition, StoredNode] = - problem.flowGraph.asInstanceOf[ReachingDefFlowGraph].numberToNode - - private val allNodes = in.keys.toList - private val containerSet = - Set( - Operators.fieldAccess, - Operators.indexAccess, - Operators.indirectIndexAccess, - Operators.indirectFieldAccess - ) - private val indirectionAccessSet = Set(Operators.addressOf, Operators.indirection) - val usedIncomingDefs: Map[StoredNode, Map[StoredNode, Set[Definition]]] = initUsedIncomingDefs() - - def initUsedIncomingDefs(): Map[StoredNode, Map[StoredNode, Set[Definition]]] = - allNodes.map { node => - node -> usedIncomingDefsForNode(node) - }.toMap - - private def usedIncomingDefsForNode(node: StoredNode): Map[StoredNode, Set[Definition]] = - uses(node).map { use => - use -> in(node).filter { inElement => - val inElemNode = numberToNode(inElement) - isUsing(use, inElemNode) + val numberToNode: Map[Definition, StoredNode] = + problem.flowGraph.asInstanceOf[ReachingDefFlowGraph].numberToNode + + private val allNodes = in.keys.toList + private val containerSet = + Set( + Operators.fieldAccess, + Operators.indexAccess, + Operators.indirectIndexAccess, + Operators.indirectFieldAccess + ) + private val indirectionAccessSet = Set(Operators.addressOf, Operators.indirection) + val usedIncomingDefs: Map[StoredNode, Map[StoredNode, Set[Definition]]] = initUsedIncomingDefs() + + def initUsedIncomingDefs(): Map[StoredNode, Map[StoredNode, Set[Definition]]] = + allNodes.map { node => + node -> usedIncomingDefsForNode(node) + }.toMap + + private def usedIncomingDefsForNode(node: StoredNode): Map[StoredNode, Set[Definition]] = + uses(node).map { use => + use -> in(node).filter { inElement => + val inElemNode = numberToNode(inElement) + isUsing(use, inElemNode) + } + }.toMap + + def isUsing(use: StoredNode, inElemNode: StoredNode): Boolean = + sameVariable(use, inElemNode) || isContainer(use, inElemNode) || isPart( + use, + inElemNode + ) || isAlias(use, inElemNode) + + /** Determine whether the node `use` describes a container for `inElement`, e.g., use = `ptr` + * while inElement = `ptr->foo`. + */ + private def isContainer(use: StoredNode, inElement: StoredNode): Boolean = + inElement match + case call: Call if containerSet.contains(call.name) => + call.argument.headOption.exists { base => + nodeToString(use) == nodeToString(base) } - }.toMap - - def isUsing(use: StoredNode, inElemNode: StoredNode): Boolean = - sameVariable(use, inElemNode) || isContainer(use, inElemNode) || isPart( - use, - inElemNode - ) || isAlias(use, inElemNode) - - /** Determine whether the node `use` describes a container for `inElement`, e.g., use = `ptr` - * while inElement = `ptr->foo`. - */ - private def isContainer(use: StoredNode, inElement: StoredNode): Boolean = - inElement match - case call: Call if containerSet.contains(call.name) => - call.argument.headOption.exists { base => - nodeToString(use) == nodeToString(base) - } - case _ => false - - /** Determine whether `use` is a part of `inElement`, e.g., use = `argv[0]` while inElement = - * `argv` - */ - private def isPart(use: StoredNode, inElement: StoredNode): Boolean = - use match - case call: Call if containerSet.contains(call.name) => - inElement match - case param: MethodParameterIn => - call.argument.headOption.exists { base => - nodeToString(base).contains(param.name) - } - case identifier: Identifier => - call.argument.headOption.exists { base => - nodeToString(base).contains(identifier.name) - } - case _ => false - case _ => false - - private def isAlias(use: StoredNode, inElement: StoredNode): Boolean = - use match - case useCall: Call => - inElement match - case inCall: Call => - val (useBase, useAccessPath) = toTrackedBaseAndAccessPathSimple(useCall) - val (inBase, inAccessPath) = toTrackedBaseAndAccessPathSimple(inCall) - useBase == inBase && useAccessPath.matchAndDiff( - inAccessPath.elements - )._1 == MatchResult.EXACT_MATCH - case _ => false - case _ => false - - def uses(node: StoredNode): Set[StoredNode] = - val n: Set[StoredNode] = node match - case ret: Return => ret.astChildren.collect { case x: Expression => x }.toSet - case call: Call => call.argument.toSet - case paramOut: MethodParameterOut => Set(paramOut) - case _ => Set() - n.filterNot(_.isInstanceOf[FieldIdentifier]) - - /** Compares arguments of calls with incoming definitions to see if they refer to the same - * variable - */ - private def sameVariable(use: StoredNode, inElement: StoredNode): Boolean = - inElement match - case param: MethodParameterIn => - nodeToString(use).contains(param.name) || nodeToString(use).contains(param.code) - case call: Call if indirectionAccessSet.contains(call.name) => - call.argumentOption(1).exists(x => nodeToString(use).contains(x.code)) - case call: Call => - nodeToString(use).contains(call.code) - case identifier: Identifier => - nodeToString(use).contains(identifier.name) || nodeToString(use).contains( - identifier.code - ) - case _ => false - - private def nodeToString(node: StoredNode): Option[String] = - node match - case ident: Identifier => Some(ident.name) - case exp: Expression => Some(exp.code) - case p: MethodParameterIn => Some(p.name) - case p: MethodParameterOut => Some(p.name) - case _ => None + case _ => false + + /** Determine whether `use` is a part of `inElement`, e.g., use = `argv[0]` while inElement = + * `argv` + */ + private def isPart(use: StoredNode, inElement: StoredNode): Boolean = + use match + case call: Call if containerSet.contains(call.name) => + inElement match + case param: MethodParameterIn => + call.argument.headOption.exists { base => + nodeToString(base).contains(param.name) + } + case identifier: Identifier => + call.argument.headOption.exists { base => + nodeToString(base).contains(identifier.name) + } + case _ => false + case _ => false + + private def isAlias(use: StoredNode, inElement: StoredNode): Boolean = + use match + case useCall: Call => + inElement match + case inCall: Call => + val (useBase, useAccessPath) = toTrackedBaseAndAccessPathSimple(useCall) + val (inBase, inAccessPath) = toTrackedBaseAndAccessPathSimple(inCall) + useBase == inBase && useAccessPath.matchAndDiff( + inAccessPath.elements + )._1 == MatchResult.EXACT_MATCH + case _ => false + case _ => false + + def uses(node: StoredNode): Set[StoredNode] = + val n: Set[StoredNode] = node match + case ret: Return => ret.astChildren.collect { case x: Expression => x }.toSet + case call: Call => call.argument.toSet + case paramOut: MethodParameterOut => Set(paramOut) + case _ => Set() + n.filterNot(_.isInstanceOf[FieldIdentifier]) + + /** Compares arguments of calls with incoming definitions to see if they refer to the same + * variable + */ + private def sameVariable(use: StoredNode, inElement: StoredNode): Boolean = + inElement match + case param: MethodParameterIn => + nodeToString(use).contains(param.name) || nodeToString(use).contains(param.code) + case call: Call if indirectionAccessSet.contains(call.name) => + call.argumentOption(1).exists(x => nodeToString(use).contains(x.code)) + case call: Call => + nodeToString(use).contains(call.code) + case identifier: Identifier => + nodeToString(use).contains(identifier.name) || nodeToString(use).contains( + identifier.code + ) + case _ => false + + private def nodeToString(node: StoredNode): Option[String] = + node match + case ident: Identifier => Some(ident.name) + case exp: Expression => Some(exp.code) + case p: MethodParameterIn => Some(p.name) + case p: MethodParameterOut => Some(p.name) + case _ => None end UsageAnalyzer diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/EdgeValidator.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/EdgeValidator.scala index 2ab75eda..69df65ee 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/EdgeValidator.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/EdgeValidator.scala @@ -14,51 +14,49 @@ import io.shiftleft.semanticcpg.language.* object EdgeValidator: - /** Determines whether the edge from `parentNode`to `childNode` is valid, according to the given - * semantics. - */ - def isValidEdge(childNode: CfgNode, parentNode: CfgNode)(implicit - semantics: Semantics - ): Boolean = - (childNode, parentNode) match - case (childNode: Expression, parentNode) - if isCallRetval(parentNode) || !isValidEdgeToExpression(parentNode, childNode) => - false - case (childNode: Call, parentNode: Expression) - if isCallRetval(childNode) && childNode.argument.contains(parentNode) => - // e.g. foo(x), but there are semantics for `foo` that don't taint its return value - // in which case we don't want `x` to taint `foo(x)`. - false - case (childNode: Expression, parentNode: Expression) - if parentNode.isArgToSameCallWith( - childNode - ) && childNode.isDefined && parentNode.isUsed => - parentNode.hasDefinedFlowTo(childNode) - case (_: Expression, _: Expression) => true - case (childNode: Expression, _) if !childNode.isUsed => false - case (_: Expression, _) => true - case (_, parentNode) => !isCallRetval(parentNode) + /** Determines whether the edge from `parentNode`to `childNode` is valid, according to the given + * semantics. + */ + def isValidEdge(childNode: CfgNode, parentNode: CfgNode)(implicit semantics: Semantics): Boolean = + (childNode, parentNode) match + case (childNode: Expression, parentNode) + if isCallRetval(parentNode) || !isValidEdgeToExpression(parentNode, childNode) => + false + case (childNode: Call, parentNode: Expression) + if isCallRetval(childNode) && childNode.argument.contains(parentNode) => + // e.g. foo(x), but there are semantics for `foo` that don't taint its return value + // in which case we don't want `x` to taint `foo(x)`. + false + case (childNode: Expression, parentNode: Expression) + if parentNode.isArgToSameCallWith( + childNode + ) && childNode.isDefined && parentNode.isUsed => + parentNode.hasDefinedFlowTo(childNode) + case (_: Expression, _: Expression) => true + case (childNode: Expression, _) if !childNode.isUsed => false + case (_: Expression, _) => true + case (_, parentNode) => !isCallRetval(parentNode) - private def isValidEdgeToExpression(parNode: CfgNode, curNode: Expression)(implicit - semantics: Semantics - ): Boolean = - parNode match - case parentNode: Expression => - val sameCallSite = parentNode.inCall.l == curNode.start.inCall.l - !(sameCallSite && isOutputArgOfInternalMethod(parentNode)) && - (sameCallSite && parentNode.isUsed && curNode.isDefined || !sameCallSite && curNode.isUsed) - case _ => - curNode.isUsed + private def isValidEdgeToExpression(parNode: CfgNode, curNode: Expression)(implicit + semantics: Semantics + ): Boolean = + parNode match + case parentNode: Expression => + val sameCallSite = parentNode.inCall.l == curNode.start.inCall.l + !(sameCallSite && isOutputArgOfInternalMethod(parentNode)) && + (sameCallSite && parentNode.isUsed && curNode.isDefined || !sameCallSite && curNode.isUsed) + case _ => + curNode.isUsed - private def isCallRetval(parentNode: StoredNode)(implicit semantics: Semantics): Boolean = - parentNode match - case call: Call => - val sem = semantics.forMethod(call.methodFullName) - sem.isDefined && !sem.get.mappings.exists { - case FlowMapping(_, ParameterNode(dst, _)) => dst == -1 - case PassThroughMapping => true - case _ => false - } - case _ => - false + private def isCallRetval(parentNode: StoredNode)(implicit semantics: Semantics): Boolean = + parentNode match + case call: Call => + val sem = semantics.forMethod(call.methodFullName) + sem.isDefined && !sem.get.mappings.exists { + case FlowMapping(_, ParameterNode(dst, _)) => dst == -1 + case PassThroughMapping => true + case _ => false + } + case _ => + false end EdgeValidator diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/ReachingDefPass.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/ReachingDefPass.scala index 83f6f75c..4a865cfb 100755 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/ReachingDefPass.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/ReachingDefPass.scala @@ -15,44 +15,44 @@ import scala.collection.mutable class ReachingDefPass(cpg: Cpg, maxNumberOfDefinitions: Int = 4000)(implicit s: Semantics) extends ForkJoinParallelCpgPass[Method](cpg): - private val logger: Logger = LoggerFactory.getLogger(this.getClass) - // If there are any regex method full names, load them early - s.loadRegexSemantics(cpg) - - override def generateParts(): Array[Method] = cpg.method.toArray - - override def runOnPart(dstGraph: DiffGraphBuilder, method: Method): Unit = - logger.debug( - "Calculating reaching definitions for: {} in {}", - method.fullName, - method.filename - ) - val problem = ReachingDefProblem.create(method) - if shouldBailOut(method, problem) then - logger.warn("Skipping.") - return - - val solution = new DataFlowSolver().calculateMopSolutionForwards(problem) - val ddgGenerator = new DdgGenerator(s) - ddgGenerator.addReachingDefEdges(dstGraph, method, problem, solution) - - /** Before we start propagating definitions in the graph, which is the bulk of the work, we - * check how many definitions were are dealing with in total. If a threshold is reached, we - * bail out instead, leaving reaching definitions uncalculated for the method in question. - * Users can increase the threshold if desired. - */ - private def shouldBailOut( - method: Method, - problem: DataFlowProblem[StoredNode, mutable.BitSet] - ): Boolean = - val transferFunction = problem.transferFunction.asInstanceOf[ReachingDefTransferFunction] - // For each node, the `gen` map contains the list of definitions it generates - // We add up the sizes of these lists to obtain the total number of definitions - val numberOfDefinitions = transferFunction.gen.foldLeft(0)(_ + _._2.size) - logger.debug("Number of definitions for {}: {}", method.fullName, numberOfDefinitions) - if numberOfDefinitions > maxNumberOfDefinitions then - logger.warn("{} has more than {} definitions", method.fullName, maxNumberOfDefinitions) - true - else - false + private val logger: Logger = LoggerFactory.getLogger(this.getClass) + // If there are any regex method full names, load them early + s.loadRegexSemantics(cpg) + + override def generateParts(): Array[Method] = cpg.method.toArray + + override def runOnPart(dstGraph: DiffGraphBuilder, method: Method): Unit = + logger.debug( + "Calculating reaching definitions for: {} in {}", + method.fullName, + method.filename + ) + val problem = ReachingDefProblem.create(method) + if shouldBailOut(method, problem) then + logger.warn("Skipping.") + return + + val solution = new DataFlowSolver().calculateMopSolutionForwards(problem) + val ddgGenerator = new DdgGenerator(s) + ddgGenerator.addReachingDefEdges(dstGraph, method, problem, solution) + + /** Before we start propagating definitions in the graph, which is the bulk of the work, we check + * how many definitions were are dealing with in total. If a threshold is reached, we bail out + * instead, leaving reaching definitions uncalculated for the method in question. Users can + * increase the threshold if desired. + */ + private def shouldBailOut( + method: Method, + problem: DataFlowProblem[StoredNode, mutable.BitSet] + ): Boolean = + val transferFunction = problem.transferFunction.asInstanceOf[ReachingDefTransferFunction] + // For each node, the `gen` map contains the list of definitions it generates + // We add up the sizes of these lists to obtain the total number of definitions + val numberOfDefinitions = transferFunction.gen.foldLeft(0)(_ + _._2.size) + logger.debug("Number of definitions for {}: {}", method.fullName, numberOfDefinitions) + if numberOfDefinitions > maxNumberOfDefinitions then + logger.warn("{} has more than {} definitions", method.fullName, maxNumberOfDefinitions) + true + else + false end ReachingDefPass diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/ReachingDefProblem.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/ReachingDefProblem.scala index f6b6a7f3..4d750549 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/ReachingDefProblem.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/ReachingDefProblem.scala @@ -15,141 +15,140 @@ import scala.collection.{Set, mutable} * readability. */ object Definition: - def fromNode(node: StoredNode, nodeToNumber: Map[StoredNode, Int]): Definition = - nodeToNumber(node) + def fromNode(node: StoredNode, nodeToNumber: Map[StoredNode, Int]): Definition = + nodeToNumber(node) object ReachingDefProblem: - def create(method: Method): DataFlowProblem[StoredNode, mutable.BitSet] = - val flowGraph = new ReachingDefFlowGraph(method) - val transfer = new OptimizedReachingDefTransferFunction(flowGraph) - val init = new ReachingDefInit(transfer.gen) - def meet: (mutable.BitSet, mutable.BitSet) => mutable.BitSet = - (x: mutable.BitSet, y: mutable.BitSet) => x.union(y) - - new DataFlowProblem[StoredNode, mutable.BitSet]( - flowGraph, - transfer, - meet, - init, - true, - mutable.BitSet() - ) + def create(method: Method): DataFlowProblem[StoredNode, mutable.BitSet] = + val flowGraph = new ReachingDefFlowGraph(method) + val transfer = new OptimizedReachingDefTransferFunction(flowGraph) + val init = new ReachingDefInit(transfer.gen) + def meet: (mutable.BitSet, mutable.BitSet) => mutable.BitSet = + (x: mutable.BitSet, y: mutable.BitSet) => x.union(y) + + new DataFlowProblem[StoredNode, mutable.BitSet]( + flowGraph, + transfer, + meet, + init, + true, + mutable.BitSet() + ) /** The control flow graph as viewed by the data flow solver. */ class ReachingDefFlowGraph(val method: Method) extends FlowGraph[StoredNode]: - private val logger: Logger = LoggerFactory.getLogger(this.getClass) - - val entryNode: StoredNode = method - val exitNode: StoredNode = method.methodReturn - - private val params = method.parameter.sortBy(_.index) - private val firstParam = params.headOption - private val lastParam = params.lastOption - private val firstOutputParam = firstParam.flatMap(_.asOutput.headOption) - private val lastOutputParam = method.parameter.sortBy(_.index).asOutput.lastOption - - private val lastActualCfgNode = exitNode._cfgIn.nextOption() - - val allNodesReversePostOrder: List[StoredNode] = - List(entryNode) ++ method.parameter.toList ++ method.reversePostOrder.toList.filter(x => - x.id != entryNode.id && x.id != exitNode.id - ) ++ method.parameter.asOutput.toList ++ List(exitNode) - - private val allNodesEvenUnreachable = - allNodesReversePostOrder ++ method.cfgNode.l.filterNot(x => - allNodesReversePostOrder.contains(x) - ) - val nodeToNumber: Map[StoredNode, Int] = - allNodesEvenUnreachable.zipWithIndex.map { case (x, i) => x -> i }.toMap - val numberToNode: Map[Int, StoredNode] = - allNodesEvenUnreachable.zipWithIndex.map { case (x, i) => i -> x }.toMap - - val allNodesPostOrder: List[StoredNode] = allNodesReversePostOrder.reverse - - private val _succ: Map[StoredNode, List[StoredNode]] = initSucc(allNodesReversePostOrder) - private val _pred: Map[StoredNode, List[StoredNode]] = - initPred(allNodesReversePostOrder, method) - - override def succ(node: StoredNode): IterableOnce[StoredNode] = - _succ.apply(node) - - override def pred(node: StoredNode): IterableOnce[StoredNode] = - _pred.apply(node) - - /** Create a map that allows CFG successors to be retrieved for each node - */ - private def initSucc(ns: List[StoredNode]): Map[StoredNode, List[StoredNode]] = - ns.map { - case n: Method => n -> firstParamOrBody(n) - case ret: Return => ret -> List(firstOutputParam.getOrElse(exitNode)) - case param: MethodParameterIn => - param -> nextParamOrBody(param) - case paramOut: MethodParameterOut => paramOut -> nextParamOutOrExit(paramOut) - case cfgNode: CfgNode => cfgNode -> cfgNextOrFirstOutParam(cfgNode) - case n => - logger.warn(s"Node type ${n.getClass.getSimpleName} should not be part of the CFG") - n -> List() - }.toMap - - /** Create a map that allows CFG predecessors to be retrieved for each node - */ - private def initPred(ns: List[StoredNode], method: Method): Map[StoredNode, List[StoredNode]] = - ns.map { - case param: MethodParameterIn => param -> previousParamOrEntry(param) - case paramOut: MethodParameterOut => - paramOut -> previousOutputParamOrLastNodeOfBody(paramOut) - case n: CfgNode if method.cfgFirst.headOption.contains(n) => - n -> List(lastParam.getOrElse(method)) - case n if n == exitNode => n -> lastOutputParamOrLastNodeOfBody() - case n @ (cfgNode: CfgNode) => n -> cfgNode.cfgPrev.l - case n => - logger.warn(s"Node type ${n.getClass.getSimpleName} should not be part of the CFG") - n -> List() - }.toMap - - private def firstParamOrBody(n: Method): List[StoredNode] = - if firstParam.isDefined then - firstParam.toList - else - cfgNext(n) - - private def cfgNext(n: CfgNode): List[StoredNode] = - n.out(EdgeTypes.CFG).map(_.asInstanceOf[StoredNode]).l - - private def nextParamOrBody(param: MethodParameterIn): List[StoredNode] = - val nextParam = param.method.parameter.index(param.index + 1).headOption - if nextParam.isDefined then nextParam.toList - else param.method.cfgFirst.l - - private def nextParamOutOrExit(paramOut: MethodParameterOut): List[StoredNode] = - val nextParam = paramOut.method.parameter.index(paramOut.index + 1).asOutput.headOption - if nextParam.isDefined then nextParam.toList - else List(exitNode) - - private def cfgNextOrFirstOutParam(cfgNode: CfgNode): List[StoredNode] = - // `.cfgNext` would be wrong here because it filters `METHOD_RETURN` - val successors = cfgNode.out(EdgeTypes.CFG).map(_.asInstanceOf[StoredNode]).l - if successors == List(exitNode) && firstOutputParam.isDefined then - List(firstOutputParam.get) - else - successors - - private def previousParamOrEntry(param: MethodParameterIn): List[StoredNode] = - val prevParam = param.method.parameter.index(param.index - 1).headOption - if prevParam.isDefined then prevParam.toList - else List(method) - - private def previousOutputParamOrLastNodeOfBody(paramOut: MethodParameterOut) - : List[StoredNode] = - val prevParam = paramOut.method.parameter.index(paramOut.index - 1).asOutput.headOption - if prevParam.isDefined then prevParam.toList - else lastActualCfgNode.toList - - private def lastOutputParamOrLastNodeOfBody(): List[StoredNode] = - if lastOutputParam.isDefined then lastOutputParam.toList - else lastActualCfgNode.toList + private val logger: Logger = LoggerFactory.getLogger(this.getClass) + + val entryNode: StoredNode = method + val exitNode: StoredNode = method.methodReturn + + private val params = method.parameter.sortBy(_.index) + private val firstParam = params.headOption + private val lastParam = params.lastOption + private val firstOutputParam = firstParam.flatMap(_.asOutput.headOption) + private val lastOutputParam = method.parameter.sortBy(_.index).asOutput.lastOption + + private val lastActualCfgNode = exitNode._cfgIn.nextOption() + + val allNodesReversePostOrder: List[StoredNode] = + List(entryNode) ++ method.parameter.toList ++ method.reversePostOrder.toList.filter(x => + x.id != entryNode.id && x.id != exitNode.id + ) ++ method.parameter.asOutput.toList ++ List(exitNode) + + private val allNodesEvenUnreachable = + allNodesReversePostOrder ++ method.cfgNode.l.filterNot(x => + allNodesReversePostOrder.contains(x) + ) + val nodeToNumber: Map[StoredNode, Int] = + allNodesEvenUnreachable.zipWithIndex.map { case (x, i) => x -> i }.toMap + val numberToNode: Map[Int, StoredNode] = + allNodesEvenUnreachable.zipWithIndex.map { case (x, i) => i -> x }.toMap + + val allNodesPostOrder: List[StoredNode] = allNodesReversePostOrder.reverse + + private val _succ: Map[StoredNode, List[StoredNode]] = initSucc(allNodesReversePostOrder) + private val _pred: Map[StoredNode, List[StoredNode]] = + initPred(allNodesReversePostOrder, method) + + override def succ(node: StoredNode): IterableOnce[StoredNode] = + _succ.apply(node) + + override def pred(node: StoredNode): IterableOnce[StoredNode] = + _pred.apply(node) + + /** Create a map that allows CFG successors to be retrieved for each node + */ + private def initSucc(ns: List[StoredNode]): Map[StoredNode, List[StoredNode]] = + ns.map { + case n: Method => n -> firstParamOrBody(n) + case ret: Return => ret -> List(firstOutputParam.getOrElse(exitNode)) + case param: MethodParameterIn => + param -> nextParamOrBody(param) + case paramOut: MethodParameterOut => paramOut -> nextParamOutOrExit(paramOut) + case cfgNode: CfgNode => cfgNode -> cfgNextOrFirstOutParam(cfgNode) + case n => + logger.warn(s"Node type ${n.getClass.getSimpleName} should not be part of the CFG") + n -> List() + }.toMap + + /** Create a map that allows CFG predecessors to be retrieved for each node + */ + private def initPred(ns: List[StoredNode], method: Method): Map[StoredNode, List[StoredNode]] = + ns.map { + case param: MethodParameterIn => param -> previousParamOrEntry(param) + case paramOut: MethodParameterOut => + paramOut -> previousOutputParamOrLastNodeOfBody(paramOut) + case n: CfgNode if method.cfgFirst.headOption.contains(n) => + n -> List(lastParam.getOrElse(method)) + case n if n == exitNode => n -> lastOutputParamOrLastNodeOfBody() + case n @ (cfgNode: CfgNode) => n -> cfgNode.cfgPrev.l + case n => + logger.warn(s"Node type ${n.getClass.getSimpleName} should not be part of the CFG") + n -> List() + }.toMap + + private def firstParamOrBody(n: Method): List[StoredNode] = + if firstParam.isDefined then + firstParam.toList + else + cfgNext(n) + + private def cfgNext(n: CfgNode): List[StoredNode] = + n.out(EdgeTypes.CFG).map(_.asInstanceOf[StoredNode]).l + + private def nextParamOrBody(param: MethodParameterIn): List[StoredNode] = + val nextParam = param.method.parameter.index(param.index + 1).headOption + if nextParam.isDefined then nextParam.toList + else param.method.cfgFirst.l + + private def nextParamOutOrExit(paramOut: MethodParameterOut): List[StoredNode] = + val nextParam = paramOut.method.parameter.index(paramOut.index + 1).asOutput.headOption + if nextParam.isDefined then nextParam.toList + else List(exitNode) + + private def cfgNextOrFirstOutParam(cfgNode: CfgNode): List[StoredNode] = + // `.cfgNext` would be wrong here because it filters `METHOD_RETURN` + val successors = cfgNode.out(EdgeTypes.CFG).map(_.asInstanceOf[StoredNode]).l + if successors == List(exitNode) && firstOutputParam.isDefined then + List(firstOutputParam.get) + else + successors + + private def previousParamOrEntry(param: MethodParameterIn): List[StoredNode] = + val prevParam = param.method.parameter.index(param.index - 1).headOption + if prevParam.isDefined then prevParam.toList + else List(method) + + private def previousOutputParamOrLastNodeOfBody(paramOut: MethodParameterOut): List[StoredNode] = + val prevParam = paramOut.method.parameter.index(paramOut.index - 1).asOutput.headOption + if prevParam.isDefined then prevParam.toList + else lastActualCfgNode.toList + + private def lastOutputParamOrLastNodeOfBody(): List[StoredNode] = + if lastOutputParam.isDefined then lastOutputParam.toList + else lastActualCfgNode.toList end ReachingDefFlowGraph /** For each node of the graph, this transfer function defines how it affects the propagation of @@ -158,154 +157,152 @@ end ReachingDefFlowGraph class ReachingDefTransferFunction(flowGraph: ReachingDefFlowGraph) extends TransferFunction[StoredNode, mutable.BitSet]: - private val nodeToNumber = flowGraph.nodeToNumber - - val method: Method = flowGraph.method - - val gen: Map[StoredNode, mutable.BitSet] = - initGen(method).withDefaultValue(mutable.BitSet()) - - val kill: Map[StoredNode, mutable.BitSet] = - initKill(method, gen).withDefaultValue(mutable.BitSet()) - - /** For a given flow graph node `n` and set of definitions, apply the transfer function to - * obtain the updated set of definitions, considering `gen(n)` and `kill(n)`. - */ - override def apply(n: StoredNode, x: mutable.BitSet): mutable.BitSet = - gen(n).union(x.diff(kill(n))) - - /** Initialize the map `gen`, a map that contains generated definitions for each flow graph - * node. - */ - def initGen(method: Method): Map[StoredNode, mutable.BitSet] = - - val defsForParams = method.parameter.l.map { param => - param -> mutable.BitSet(Definition.fromNode( - param.asInstanceOf[StoredNode], - nodeToNumber - )) - } - - // We filter out field accesses to ensure that they propagate - // taint unharmed. - - val defsForCalls = method.call - .filterNot(x => isFieldAccess(x.name)) - .l - .map { call => - call -> { - val retVal = List(call) - val args = call.argument - .filter(hasValidGenType) - .l - mutable.BitSet( - (retVal ++ args) - .collect { - case x if nodeToNumber.contains(x) => - Definition.fromNode(x.asInstanceOf[StoredNode], nodeToNumber) - }* - ) - } - } - (defsForParams ++ defsForCalls).toMap - end initGen - - /** Restricts the types of nodes that represent definitions. - */ - private def hasValidGenType(node: Expression): Boolean = - node match - case _: Call => true - case _: Identifier => true - case _ => false - - /** Initialize the map `kill`, a map that contains killed definitions for each flow graph node. - * - * All operations in our graph are represented by calls and non-operations such as identifiers - * or field-identifiers have empty gen and kill sets, meaning that they just pass on - * definitions unaltered. - */ - private def initKill( - method: Method, - gen: Map[StoredNode, mutable.BitSet] - ): Map[StoredNode, mutable.BitSet] = - - val allIdentifiers: Map[String, List[CfgNode]] = - val results = mutable.Map.empty[String, List[CfgNode]] - method.ast - .collect { - case identifier: Identifier => - (identifier.name, identifier) - case methodParameterIn: MethodParameterIn => - (methodParameterIn.name, methodParameterIn) - } - .foreach { case (name, node) => - val oldValues = results.getOrElse(name, Nil) - results.put(name, node :: oldValues) - } - results.toMap - - val allCalls: Map[String, List[Call]] = - method.call.l - .groupBy(_.code) - .withDefaultValue(List.empty[Call]) - - // We filter out field accesses to ensure that they propagate - // taint unharmed. - - method.call - .filterNot(x => isGenericMemberAccessName(x.name)) - .map { call => - call -> killsForGens(gen(call), allIdentifiers, allCalls) + private val nodeToNumber = flowGraph.nodeToNumber + + val method: Method = flowGraph.method + + val gen: Map[StoredNode, mutable.BitSet] = + initGen(method).withDefaultValue(mutable.BitSet()) + + val kill: Map[StoredNode, mutable.BitSet] = + initKill(method, gen).withDefaultValue(mutable.BitSet()) + + /** For a given flow graph node `n` and set of definitions, apply the transfer function to obtain + * the updated set of definitions, considering `gen(n)` and `kill(n)`. + */ + override def apply(n: StoredNode, x: mutable.BitSet): mutable.BitSet = + gen(n).union(x.diff(kill(n))) + + /** Initialize the map `gen`, a map that contains generated definitions for each flow graph node. + */ + def initGen(method: Method): Map[StoredNode, mutable.BitSet] = + + val defsForParams = method.parameter.l.map { param => + param -> mutable.BitSet(Definition.fromNode( + param.asInstanceOf[StoredNode], + nodeToNumber + )) + } + + // We filter out field accesses to ensure that they propagate + // taint unharmed. + + val defsForCalls = method.call + .filterNot(x => isFieldAccess(x.name)) + .l + .map { call => + call -> { + val retVal = List(call) + val args = call.argument + .filter(hasValidGenType) + .l + mutable.BitSet( + (retVal ++ args) + .collect { + case x if nodeToNumber.contains(x) => + Definition.fromNode(x.asInstanceOf[StoredNode], nodeToNumber) + }* + ) } - .toMap - end initKill - - /** The only way in which a call can kill another definition is by generating a new definition - * for the same variable. Given the set of generated definitions `gens`, we calculate - * definitions of the same variable for each, that is, we calculate kill(call) based on - * gen(call). - */ - private def killsForGens( - genOfCall: mutable.BitSet, - allIdentifiers: Map[String, List[CfgNode]], - allCalls: Map[String, List[Call]] - ): mutable.BitSet = - - def definitionsOfSameVariable(definition: Definition): Iterator[Definition] = - val definedNodes = flowGraph.numberToNode(definition) match - case param: MethodParameterIn => - allIdentifiers(param.name).iterator - .filter(x => x.id != param.id) - case identifier: Identifier => - val sameIdentifiers = allIdentifiers(identifier.name).iterator - .filter(x => x.id != identifier.id) - - /** Killing an identifier should also kill field accesses on that identifier. - * For example, a reassignment `x = new Box()` should kill any previous calls - * to `x.value`, `x.length()`, etc. - */ - val sameObjects: Iterator[Call] = allCalls.valuesIterator.flatten - .filter(_.name == Operators.fieldAccess) - .filter(_.ast.isIdentifier.nameExact(identifier.name).nonEmpty) - - sameIdentifiers ++ sameObjects - case call: Call => - allCalls(call.code).iterator - .filter(x => x.id != call.id) - case _ => Iterator.empty - definedNodes - // It can happen that the CFG is broken and contains isolated nodes, - // in which case they are not in `nodeToNumber`. Let's filter those. - .collect { - case x if nodeToNumber.contains(x) => Definition.fromNode(x, nodeToNumber) - } - end definitionsOfSameVariable - - val res = mutable.BitSet() - for definition <- genOfCall do - res.addAll(definitionsOfSameVariable(definition)) - res - end killsForGens + } + (defsForParams ++ defsForCalls).toMap + end initGen + + /** Restricts the types of nodes that represent definitions. + */ + private def hasValidGenType(node: Expression): Boolean = + node match + case _: Call => true + case _: Identifier => true + case _ => false + + /** Initialize the map `kill`, a map that contains killed definitions for each flow graph node. + * + * All operations in our graph are represented by calls and non-operations such as identifiers or + * field-identifiers have empty gen and kill sets, meaning that they just pass on definitions + * unaltered. + */ + private def initKill( + method: Method, + gen: Map[StoredNode, mutable.BitSet] + ): Map[StoredNode, mutable.BitSet] = + + val allIdentifiers: Map[String, List[CfgNode]] = + val results = mutable.Map.empty[String, List[CfgNode]] + method.ast + .collect { + case identifier: Identifier => + (identifier.name, identifier) + case methodParameterIn: MethodParameterIn => + (methodParameterIn.name, methodParameterIn) + } + .foreach { case (name, node) => + val oldValues = results.getOrElse(name, Nil) + results.put(name, node :: oldValues) + } + results.toMap + + val allCalls: Map[String, List[Call]] = + method.call.l + .groupBy(_.code) + .withDefaultValue(List.empty[Call]) + + // We filter out field accesses to ensure that they propagate + // taint unharmed. + + method.call + .filterNot(x => isGenericMemberAccessName(x.name)) + .map { call => + call -> killsForGens(gen(call), allIdentifiers, allCalls) + } + .toMap + end initKill + + /** The only way in which a call can kill another definition is by generating a new definition for + * the same variable. Given the set of generated definitions `gens`, we calculate definitions of + * the same variable for each, that is, we calculate kill(call) based on gen(call). + */ + private def killsForGens( + genOfCall: mutable.BitSet, + allIdentifiers: Map[String, List[CfgNode]], + allCalls: Map[String, List[Call]] + ): mutable.BitSet = + + def definitionsOfSameVariable(definition: Definition): Iterator[Definition] = + val definedNodes = flowGraph.numberToNode(definition) match + case param: MethodParameterIn => + allIdentifiers(param.name).iterator + .filter(x => x.id != param.id) + case identifier: Identifier => + val sameIdentifiers = allIdentifiers(identifier.name).iterator + .filter(x => x.id != identifier.id) + + /** Killing an identifier should also kill field accesses on that identifier. For + * example, a reassignment `x = new Box()` should kill any previous calls to `x.value`, + * `x.length()`, etc. + */ + val sameObjects: Iterator[Call] = allCalls.valuesIterator.flatten + .filter(_.name == Operators.fieldAccess) + .filter(_.ast.isIdentifier.nameExact(identifier.name).nonEmpty) + + sameIdentifiers ++ sameObjects + case call: Call => + allCalls(call.code).iterator + .filter(x => x.id != call.id) + case _ => Iterator.empty + definedNodes + // It can happen that the CFG is broken and contains isolated nodes, + // in which case they are not in `nodeToNumber`. Let's filter those. + .collect { + case x if nodeToNumber.contains(x) => Definition.fromNode(x, nodeToNumber) + } + end definitionsOfSameVariable + + val res = mutable.BitSet() + for definition <- genOfCall do + res.addAll(definitionsOfSameVariable(definition)) + res + end killsForGens end ReachingDefTransferFunction /** Lone Identifier Optimization: we first determine and store all identifiers that neither refer to @@ -318,50 +315,50 @@ end ReachingDefTransferFunction class OptimizedReachingDefTransferFunction(flowGraph: ReachingDefFlowGraph) extends ReachingDefTransferFunction(flowGraph): - lazy val loneIdentifiers: Map[Call, List[Definition]] = - val identifiersInReturns = method._returnViaContainsOut.ast.isIdentifier.name.l - val paramAndLocalNames = method.parameter.name.l ++ method.local.name.l - val callArgPairs = method.call.flatMap { call => - call.argument.isIdentifier - .filterNot(i => paramAndLocalNames.contains(i.name)) - .filterNot(i => identifiersInReturns.contains(i.name)) - .map(arg => (arg.name, call, arg)) - }.l - - callArgPairs - .groupBy(_._1) - .collect { - case (_, v) if v.size == 1 => v.map { case (_, call, arg) => (call, arg) }.head - } - .toList - .groupBy(_._1) - .map { case (k, v) => - ( - k, - v.filter(x => flowGraph.nodeToNumber.contains(x._2)) - .map(x => Definition.fromNode(x._2, flowGraph.nodeToNumber)) - ) - } - end loneIdentifiers - - override def initGen(method: Method): Map[StoredNode, mutable.BitSet] = - withoutLoneIdentifiers(super.initGen(method)) - - private def withoutLoneIdentifiers(g: Map[StoredNode, mutable.BitSet]) - : Map[StoredNode, mutable.BitSet] = - g.map { case (k, defs) => - k match - case call: Call if loneIdentifiers.contains(call) => - (call, defs.filterNot(loneIdentifiers(call).contains(_))) - case _ => (k, defs) + lazy val loneIdentifiers: Map[Call, List[Definition]] = + val identifiersInReturns = method._returnViaContainsOut.ast.isIdentifier.name.l + val paramAndLocalNames = method.parameter.name.l ++ method.local.name.l + val callArgPairs = method.call.flatMap { call => + call.argument.isIdentifier + .filterNot(i => paramAndLocalNames.contains(i.name)) + .filterNot(i => identifiersInReturns.contains(i.name)) + .map(arg => (arg.name, call, arg)) + }.l + + callArgPairs + .groupBy(_._1) + .collect { + case (_, v) if v.size == 1 => v.map { case (_, call, arg) => (call, arg) }.head + } + .toList + .groupBy(_._1) + .map { case (k, v) => + ( + k, + v.filter(x => flowGraph.nodeToNumber.contains(x._2)) + .map(x => Definition.fromNode(x._2, flowGraph.nodeToNumber)) + ) } + end loneIdentifiers + + override def initGen(method: Method): Map[StoredNode, mutable.BitSet] = + withoutLoneIdentifiers(super.initGen(method)) + + private def withoutLoneIdentifiers(g: Map[StoredNode, mutable.BitSet]) + : Map[StoredNode, mutable.BitSet] = + g.map { case (k, defs) => + k match + case call: Call if loneIdentifiers.contains(call) => + (call, defs.filterNot(loneIdentifiers(call).contains(_))) + case _ => (k, defs) + } end OptimizedReachingDefTransferFunction class ReachingDefInit(gen: Map[StoredNode, mutable.BitSet]) extends InOutInit[StoredNode, mutable.BitSet]: - override def initIn: Map[StoredNode, mutable.BitSet] = - Map - .empty[StoredNode, mutable.BitSet] - .withDefaultValue(mutable.BitSet()) + override def initIn: Map[StoredNode, mutable.BitSet] = + Map + .empty[StoredNode, mutable.BitSet] + .withDefaultValue(mutable.BitSet()) - override def initOut: Map[StoredNode, mutable.BitSet] = gen + override def initOut: Map[StoredNode, mutable.BitSet] = gen diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/package.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/package.scala index 1e2de5dd..a2d6695b 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/package.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/passes/reachingdef/package.scala @@ -1,4 +1,4 @@ package io.appthreat.dataflowengineoss.passes package object reachingdef: - type Definition = Int + type Definition = Int diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/AccessPathUsage.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/AccessPathUsage.scala index 9588b1fa..3dca7266 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/AccessPathUsage.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/AccessPathUsage.scala @@ -17,40 +17,40 @@ import org.slf4j.LoggerFactory object AccessPathUsage: - private val logger = LoggerFactory.getLogger(getClass) + private val logger = LoggerFactory.getLogger(getClass) - def toTrackedBaseAndAccessPathSimple(node: StoredNode): (TrackedBase, AccessPath) = - val (base, revPath) = toTrackedBaseAndAccessPathInternal(node) - (base, AccessPath.apply(Elements.normalized(revPath.reverse), Nil)) + def toTrackedBaseAndAccessPathSimple(node: StoredNode): (TrackedBase, AccessPath) = + val (base, revPath) = toTrackedBaseAndAccessPathInternal(node) + (base, AccessPath.apply(Elements.normalized(revPath.reverse), Nil)) - private def toTrackedBaseAndAccessPathInternal(node: StoredNode) - : (TrackedBase, List[AccessElement]) = - val result = AccessPathHandling.leafToTrackedBaseAndAccessPathInternal(node) - if result.isDefined then - result.get - else - node match + private def toTrackedBaseAndAccessPathInternal(node: StoredNode) + : (TrackedBase, List[AccessElement]) = + val result = AccessPathHandling.leafToTrackedBaseAndAccessPathInternal(node) + if result.isDefined then + result.get + else + node match - case block: Block => - AccessPathHandling - .lastExpressionInBlock(block) - .map { toTrackedBaseAndAccessPathInternal } - .getOrElse((TrackedUnknown, Nil)) - case call: Call if !MemberAccess.isGenericMemberAccessName(call.name) => - (TrackedReturnValue(call), Nil) + case block: Block => + AccessPathHandling + .lastExpressionInBlock(block) + .map { toTrackedBaseAndAccessPathInternal } + .getOrElse((TrackedUnknown, Nil)) + case call: Call if !MemberAccess.isGenericMemberAccessName(call.name) => + (TrackedReturnValue(call), Nil) - case memberAccess: Call => - // assume: MemberAccess.isGenericMemberAccessName(call.name) - val argOne = memberAccess.argumentOption(1) - if argOne.isEmpty then - logger.debug(s"Missing first argument on call ${memberAccess.code}.") - return (TrackedUnknown, Nil) - val (base, tail) = toTrackedBaseAndAccessPathInternal(argOne.get) - val path = AccessPathHandling.memberAccessToPath(memberAccess, tail) - (base, path) - case _ => - logger.debug(s"Missing handling for node type ${node.getClass}.") - (TrackedUnknown, Nil) - end if - end toTrackedBaseAndAccessPathInternal + case memberAccess: Call => + // assume: MemberAccess.isGenericMemberAccessName(call.name) + val argOne = memberAccess.argumentOption(1) + if argOne.isEmpty then + logger.debug(s"Missing first argument on call ${memberAccess.code}.") + return (TrackedUnknown, Nil) + val (base, tail) = toTrackedBaseAndAccessPathInternal(argOne.get) + val path = AccessPathHandling.memberAccessToPath(memberAccess, tail) + (base, path) + case _ => + logger.debug(s"Missing handling for node type ${node.getClass}.") + (TrackedUnknown, Nil) + end if + end toTrackedBaseAndAccessPathInternal end AccessPathUsage diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/Engine.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/Engine.scala index 97822abe..cc107bf0 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/Engine.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/Engine.scala @@ -24,270 +24,269 @@ import scala.util.{Failure, Success, Try} */ class Engine(context: EngineContext): - import Engine.* - - private val logger: Logger = LoggerFactory.getLogger(this.getClass) - private val executorService: ExecutorService = - Executors.newVirtualThreadPerTaskExecutor() - private val completionService = - new ExecutorCompletionService[TaskSummary](executorService) - - /** All results of tasks are accumulated in this table. At the end of the analysis, we extract - * results from the table and return them. - */ - private val mainResultTable: mutable.Map[TaskFingerprint, List[TableEntry]] = mutable.Map() - private var numberOfTasksRunning: Int = 0 - private val started: mutable.HashSet[TaskFingerprint] = mutable.HashSet[TaskFingerprint]() - private val held: mutable.Buffer[ReachableByTask] = mutable.Buffer() - - /** Determine flows from sources to sinks by exploring the graph backwards from sinks to - * sources. Returns the list of results along with a ResultTable, a cache of known paths - * created during the analysis. - */ - def backwards(sinks: List[CfgNode], sources: List[CfgNode]): List[TableEntry] = - if sources.isEmpty then - logger.debug("Attempting to determine flows from empty list of sources.") - if sinks.isEmpty then - logger.debug("Attempting to determine flows to empty list of sinks.") - reset() - val sourcesSet = sources.toSet - val tasks = createOneTaskPerSink(sinks) - solveTasks(tasks, sourcesSet, sinks) - - private def reset(): Unit = - mainResultTable.clear() - numberOfTasksRunning = 0 - started.clear() - held.clear() - - private def createOneTaskPerSink(sinks: List[CfgNode]) = - sinks.map(sink => ReachableByTask(List(TaskFingerprint(sink, List(), 0)), Vector())) - - /** Submit tasks to a worker pool, solving them in parallel. Upon receiving results for a task, - * new tasks are submitted accordingly. Once no more tasks can be created, the list of results - * is returned. + import Engine.* + + private val logger: Logger = LoggerFactory.getLogger(this.getClass) + private val executorService: ExecutorService = + Executors.newVirtualThreadPerTaskExecutor() + private val completionService = + new ExecutorCompletionService[TaskSummary](executorService) + + /** All results of tasks are accumulated in this table. At the end of the analysis, we extract + * results from the table and return them. + */ + private val mainResultTable: mutable.Map[TaskFingerprint, List[TableEntry]] = mutable.Map() + private var numberOfTasksRunning: Int = 0 + private val started: mutable.HashSet[TaskFingerprint] = mutable.HashSet[TaskFingerprint]() + private val held: mutable.Buffer[ReachableByTask] = mutable.Buffer() + + /** Determine flows from sources to sinks by exploring the graph backwards from sinks to sources. + * Returns the list of results along with a ResultTable, a cache of known paths created during + * the analysis. + */ + def backwards(sinks: List[CfgNode], sources: List[CfgNode]): List[TableEntry] = + if sources.isEmpty then + logger.debug("Attempting to determine flows from empty list of sources.") + if sinks.isEmpty then + logger.debug("Attempting to determine flows to empty list of sinks.") + reset() + val sourcesSet = sources.toSet + val tasks = createOneTaskPerSink(sinks) + solveTasks(tasks, sourcesSet, sinks) + + private def reset(): Unit = + mainResultTable.clear() + numberOfTasksRunning = 0 + started.clear() + held.clear() + + private def createOneTaskPerSink(sinks: List[CfgNode]) = + sinks.map(sink => ReachableByTask(List(TaskFingerprint(sink, List(), 0)), Vector())) + + /** Submit tasks to a worker pool, solving them in parallel. Upon receiving results for a task, + * new tasks are submitted accordingly. Once no more tasks can be created, the list of results is + * returned. + */ + private def solveTasks( + tasks: List[ReachableByTask], + sources: Set[CfgNode], + sinks: List[CfgNode] + ): List[TableEntry] = + + /** Solving a task produces a list of summaries. The following method is called for each of + * these summaries. It submits new tasks and adds results to the result table. */ - private def solveTasks( - tasks: List[ReachableByTask], - sources: Set[CfgNode], - sinks: List[CfgNode] - ): List[TableEntry] = - - /** Solving a task produces a list of summaries. The following method is called for each of - * these summaries. It submits new tasks and adds results to the result table. - */ - def handleSummary(taskSummary: TaskSummary): Unit = - val newTasks = taskSummary.followupTasks - submitTasks(newTasks, sources) - val newResults = taskSummary.tableEntries - addEntriesToMainTable(newResults) - - def addEntriesToMainTable(entries: Vector[(TaskFingerprint, TableEntry)]): Unit = - entries.groupBy(_._1).foreach { case (fingerprint, entryList) => - val entries = entryList.map(_._2).toList - mainResultTable.updateWith(fingerprint) { - case Some(list) => Some(list ++ entries) - case None => Some(entries) - } + def handleSummary(taskSummary: TaskSummary): Unit = + val newTasks = taskSummary.followupTasks + submitTasks(newTasks, sources) + val newResults = taskSummary.tableEntries + addEntriesToMainTable(newResults) + + def addEntriesToMainTable(entries: Vector[(TaskFingerprint, TableEntry)]): Unit = + entries.groupBy(_._1).foreach { case (fingerprint, entryList) => + val entries = entryList.map(_._2).toList + mainResultTable.updateWith(fingerprint) { + case Some(list) => Some(list ++ entries) + case None => Some(entries) } - - def runUntilAllTasksAreSolved(): Unit = - while numberOfTasksRunning > 0 do - Try { - completionService.take.get - } match - case Success(resultsOfTask) => - numberOfTasksRunning -= 1 - handleSummary(resultsOfTask) - case Failure(_) => - numberOfTasksRunning -= 1 - - submitTasks(tasks.toVector, sources) - val startTimeSec: Long = System.currentTimeMillis / 1000 - runUntilAllTasksAreSolved() - val taskFinishTimeSec: Long = System.currentTimeMillis / 1000 - logger.debug( - "Time measurement -----> Task processing done in " + - (taskFinishTimeSec - startTimeSec) + " seconds" - ) - new HeldTaskCompletion(held.toList, mainResultTable).completeHeldTasks() - val dedupResult = deduplicateFinal(extractResultsFromTable(sinks)) - val allDoneTimeSec: Long = System.currentTimeMillis / 1000 - - logger.debug( - "Time measurement -----> Task processing: " + - (taskFinishTimeSec - startTimeSec) + " seconds" + - ", Deduplication: " + (allDoneTimeSec - taskFinishTimeSec) + - ", Deduped results size: " + dedupResult.length - ) - dedupResult - end solveTasks - - private def submitTasks(tasks: Vector[ReachableByTask], sources: Set[CfgNode]): Unit = - tasks.foreach { task => - if started.contains(task.fingerprint) then - held ++= Vector(task) - else - started.add(task.fingerprint) - numberOfTasksRunning += 1 - completionService.submit(new TaskSolver(task, context, sources)) } - private def extractResultsFromTable(sinks: List[CfgNode]): List[TableEntry] = - sinks.flatMap { sink => - mainResultTable.get(TaskFingerprint(sink, List(), 0)) match - case Some(results) => results - case _ => Vector() - } - - private def deduplicateFinal(list: List[TableEntry]): List[TableEntry] = - list - .groupBy { result => - val head = result.path.head.node - val last = result.path.last.node - (head, last) - } - .map { case (_, list) => - val lenIdPathPairs = list.map(x => (x.path.length, x)) - val withMaxLength = (lenIdPathPairs.sortBy(_._1).reverse match - case Nil => Nil - case h :: t => h :: t.takeWhile(y => y._1 == h._1) - ).map(_._2) - - if withMaxLength.length == 1 then - withMaxLength.head - else - withMaxLength.minBy { x => - x.path - .map(x => - ( - x.node.id, - x.callSiteStack.map(_.id), - x.visible, - x.isOutputArg, - x.outEdgeLabel - ).toString - ) - .mkString("-") - } - } - .toList + def runUntilAllTasksAreSolved(): Unit = + while numberOfTasksRunning > 0 do + Try { + completionService.take.get + } match + case Success(resultsOfTask) => + numberOfTasksRunning -= 1 + handleSummary(resultsOfTask) + case Failure(_) => + numberOfTasksRunning -= 1 + + submitTasks(tasks.toVector, sources) + val startTimeSec: Long = System.currentTimeMillis / 1000 + runUntilAllTasksAreSolved() + val taskFinishTimeSec: Long = System.currentTimeMillis / 1000 + logger.debug( + "Time measurement -----> Task processing done in " + + (taskFinishTimeSec - startTimeSec) + " seconds" + ) + new HeldTaskCompletion(held.toList, mainResultTable).completeHeldTasks() + val dedupResult = deduplicateFinal(extractResultsFromTable(sinks)) + val allDoneTimeSec: Long = System.currentTimeMillis / 1000 + + logger.debug( + "Time measurement -----> Task processing: " + + (taskFinishTimeSec - startTimeSec) + " seconds" + + ", Deduplication: " + (allDoneTimeSec - taskFinishTimeSec) + + ", Deduped results size: " + dedupResult.length + ) + dedupResult + end solveTasks + + private def submitTasks(tasks: Vector[ReachableByTask], sources: Set[CfgNode]): Unit = + tasks.foreach { task => + if started.contains(task.fingerprint) then + held ++= Vector(task) + else + started.add(task.fingerprint) + numberOfTasksRunning += 1 + completionService.submit(new TaskSolver(task, context, sources)) + } + + private def extractResultsFromTable(sinks: List[CfgNode]): List[TableEntry] = + sinks.flatMap { sink => + mainResultTable.get(TaskFingerprint(sink, List(), 0)) match + case Some(results) => results + case _ => Vector() + } + + private def deduplicateFinal(list: List[TableEntry]): List[TableEntry] = + list + .groupBy { result => + val head = result.path.head.node + val last = result.path.last.node + (head, last) + } + .map { case (_, list) => + val lenIdPathPairs = list.map(x => (x.path.length, x)) + val withMaxLength = (lenIdPathPairs.sortBy(_._1).reverse match + case Nil => Nil + case h :: t => h :: t.takeWhile(y => y._1 == h._1) + ).map(_._2) + + if withMaxLength.length == 1 then + withMaxLength.head + else + withMaxLength.minBy { x => + x.path + .map(x => + ( + x.node.id, + x.callSiteStack.map(_.id), + x.visible, + x.isOutputArg, + x.outEdgeLabel + ).toString + ) + .mkString("-") + } + } + .toList - /** This must be called when one is done using the engine. - */ - def shutdown(): Unit = - executorService.shutdown() + /** This must be called when one is done using the engine. + */ + def shutdown(): Unit = + executorService.shutdown() end Engine object Engine: - /** Traverse from a node to incoming DDG nodes, taking into account semantics. This method is - * exposed via the `ddgIn` step, but is also called by the engine internally by the - * `TaskSolver`. - * - * @param curNode - * the node to expand - * @param path - * the path that has been expanded to reach the `curNode` - */ - def expandIn( - curNode: CfgNode, - path: Vector[PathElement], - callSiteStack: List[Call] = List() - )(implicit semantics: Semantics): Vector[PathElement] = - ddgInE(curNode, path, callSiteStack).flatMap(x => elemForEdge(x, callSiteStack)) - - private def elemForEdge(e: Edge, callSiteStack: List[Call] = List())(implicit - semantics: Semantics - ): Option[PathElement] = - val curNode = e.inNode().asInstanceOf[CfgNode] - val parNode = e.outNode().asInstanceOf[CfgNode] - val outLabel = Some(e.property(Properties.VARIABLE)).getOrElse("") - - if !EdgeValidator.isValidEdge(curNode, parNode) then - return None - - curNode match - case childNode: Expression => - parNode match - case parentNode: Expression => - val parentNodeCall = parentNode.inCall.l - val sameCallSite = parentNode.inCall.l == childNode.start.inCall.l - val visible = if sameCallSite then - val semanticExists = parentNode.semanticsForCallByArg.nonEmpty - val internalMethodsForCall = - parentNodeCall.flatMap(methodsForCall).internal - (semanticExists && parentNode.isDefined) || internalMethodsForCall.isEmpty - else - parentNode.isDefined - val isOutputArg = isOutputArgOfInternalMethod(parentNode) - Some(PathElement( - parentNode, - callSiteStack, - visible, - isOutputArg, - outEdgeLabel = outLabel - )) - case parentNode if parentNode != null => - Some(PathElement(parentNode, callSiteStack, outEdgeLabel = outLabel)) - case null => - None - case _ => - Some(PathElement(parNode, callSiteStack, outEdgeLabel = outLabel)) - end match - end elemForEdge - - def isOutputArgOfInternalMethod(arg: Expression)(implicit semantics: Semantics): Boolean = - arg.inCall.l match - case List(call) => - methodsForCall(call).internal.isNotStub.nonEmpty && semanticsForCall(call).isEmpty - case _ => - false - - /** For a given node `node`, return all incoming reaching definition edges, unless the source - * node is (a) a METHOD node, (b) already present on `path`, or (c) a CALL node to a method - * where the semantic indicates that taint is propagated to it. - */ - private def ddgInE( - node: CfgNode, - path: Vector[PathElement], - callSiteStack: List[Call] = List() - ): Vector[Edge] = - node - .inE(EdgeTypes.REACHING_DEF) - .asScala - .filter { e => - e.outNode() match - case srcNode: CfgNode => - !srcNode.isInstanceOf[Method] && !path - .map(x => x.node) - .contains(srcNode) - case _ => false - } - .toVector - - def argToOutputParams(arg: Expression): Iterator[MethodParameterOut] = - argToMethods(arg).parameter - .index(arg.argumentIndex) - .asOutput - - def argToMethods(arg: Expression): List[Method] = - arg.inCall.l.flatMap { call => - methodsForCall(call) - } - - def methodsForCall(call: Call): List[Method] = - NoResolve.getCalledMethods(call).toList - - def isCallToInternalMethod(call: Call): Boolean = - methodsForCall(call).internal.nonEmpty - def isCallToInternalMethodWithoutSemantic(call: Call)(implicit semantics: Semantics): Boolean = - isCallToInternalMethod(call) && semanticsForCall(call).isEmpty - - def semanticsForCall(call: Call)(implicit semantics: Semantics): List[FlowSemantic] = - Engine.methodsForCall(call).flatMap { method => - semantics.forMethod(method.fullName) - } + /** Traverse from a node to incoming DDG nodes, taking into account semantics. This method is + * exposed via the `ddgIn` step, but is also called by the engine internally by the `TaskSolver`. + * + * @param curNode + * the node to expand + * @param path + * the path that has been expanded to reach the `curNode` + */ + def expandIn( + curNode: CfgNode, + path: Vector[PathElement], + callSiteStack: List[Call] = List() + )(implicit semantics: Semantics): Vector[PathElement] = + ddgInE(curNode, path, callSiteStack).flatMap(x => elemForEdge(x, callSiteStack)) + + private def elemForEdge(e: Edge, callSiteStack: List[Call] = List())(implicit + semantics: Semantics + ): Option[PathElement] = + val curNode = e.inNode().asInstanceOf[CfgNode] + val parNode = e.outNode().asInstanceOf[CfgNode] + val outLabel = Some(e.property(Properties.VARIABLE)).getOrElse("") + + if !EdgeValidator.isValidEdge(curNode, parNode) then + return None + + curNode match + case childNode: Expression => + parNode match + case parentNode: Expression => + val parentNodeCall = parentNode.inCall.l + val sameCallSite = parentNode.inCall.l == childNode.start.inCall.l + val visible = if sameCallSite then + val semanticExists = parentNode.semanticsForCallByArg.nonEmpty + val internalMethodsForCall = + parentNodeCall.flatMap(methodsForCall).internal + (semanticExists && parentNode.isDefined) || internalMethodsForCall.isEmpty + else + parentNode.isDefined + val isOutputArg = isOutputArgOfInternalMethod(parentNode) + Some(PathElement( + parentNode, + callSiteStack, + visible, + isOutputArg, + outEdgeLabel = outLabel + )) + case parentNode if parentNode != null => + Some(PathElement(parentNode, callSiteStack, outEdgeLabel = outLabel)) + case null => + None + case _ => + Some(PathElement(parNode, callSiteStack, outEdgeLabel = outLabel)) + end match + end elemForEdge + + def isOutputArgOfInternalMethod(arg: Expression)(implicit semantics: Semantics): Boolean = + arg.inCall.l match + case List(call) => + methodsForCall(call).internal.isNotStub.nonEmpty && semanticsForCall(call).isEmpty + case _ => + false + + /** For a given node `node`, return all incoming reaching definition edges, unless the source node + * is (a) a METHOD node, (b) already present on `path`, or (c) a CALL node to a method where the + * semantic indicates that taint is propagated to it. + */ + private def ddgInE( + node: CfgNode, + path: Vector[PathElement], + callSiteStack: List[Call] = List() + ): Vector[Edge] = + node + .inE(EdgeTypes.REACHING_DEF) + .asScala + .filter { e => + e.outNode() match + case srcNode: CfgNode => + !srcNode.isInstanceOf[Method] && !path + .map(x => x.node) + .contains(srcNode) + case _ => false + } + .toVector + + def argToOutputParams(arg: Expression): Iterator[MethodParameterOut] = + argToMethods(arg).parameter + .index(arg.argumentIndex) + .asOutput + + def argToMethods(arg: Expression): List[Method] = + arg.inCall.l.flatMap { call => + methodsForCall(call) + } + + def methodsForCall(call: Call): List[Method] = + NoResolve.getCalledMethods(call).toList + + def isCallToInternalMethod(call: Call): Boolean = + methodsForCall(call).internal.nonEmpty + def isCallToInternalMethodWithoutSemantic(call: Call)(implicit semantics: Semantics): Boolean = + isCallToInternalMethod(call) && semanticsForCall(call).isEmpty + + def semanticsForCall(call: Call)(implicit semantics: Semantics): List[FlowSemantic] = + Engine.methodsForCall(call).flatMap { method => + semantics.forMethod(method.fullName) + } end Engine /** The execution context for the data flow engine. @@ -326,31 +325,31 @@ case class EngineConfig( */ object QueryEngineStatistics extends Enumeration: - type QueryEngineStatistic = Value + type QueryEngineStatistic = Value - val PATH_CACHE_HITS, PATH_CACHE_MISSES = Value + val PATH_CACHE_HITS, PATH_CACHE_MISSES = Value - private val statistics = new ConcurrentHashMap[QueryEngineStatistic, Long]() + private val statistics = new ConcurrentHashMap[QueryEngineStatistic, Long]() - reset() + reset() - /** Adds the given value to the associated value to the given [[QueryEngineStatistics]] key. - * @param key - * the key associated with the value to transform. - * @param value - * the value to add to the statistic. Can be negative. - */ - def incrementBy(key: QueryEngineStatistic, value: Long): Unit = - statistics.put(key, statistics.getOrDefault(key, 0L) + value) + /** Adds the given value to the associated value to the given [[QueryEngineStatistics]] key. + * @param key + * the key associated with the value to transform. + * @param value + * the value to add to the statistic. Can be negative. + */ + def incrementBy(key: QueryEngineStatistic, value: Long): Unit = + statistics.put(key, statistics.getOrDefault(key, 0L) + value) - /** The results of the measured statistics. - * @return - * a map of each [[QueryEngineStatistic]] and the associated value measurement. - */ - def results(): Map[QueryEngineStatistic, Long] = statistics.asScala.toMap + /** The results of the measured statistics. + * @return + * a map of each [[QueryEngineStatistic]] and the associated value measurement. + */ + def results(): Map[QueryEngineStatistic, Long] = statistics.asScala.toMap - /** Sets all the tracked values back to 0. - */ - def reset(): Unit = - QueryEngineStatistics.values.map((_, 0L)).foreach { case (v, t) => statistics.put(v, t) } + /** Sets all the tracked values back to 0. + */ + def reset(): Unit = + QueryEngineStatistics.values.map((_, 0L)).foreach { case (v, t) => statistics.put(v, t) } end QueryEngineStatistics diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/HeldTaskCompletion.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/HeldTaskCompletion.scala index c268671e..11bd586d 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/HeldTaskCompletion.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/HeldTaskCompletion.scala @@ -22,159 +22,159 @@ class HeldTaskCompletion( resultTable: mutable.Map[TaskFingerprint, List[TableEntry]] ): - /** Add results produced by held task until no more change can be observed. - * - * We use the following simple algorithm (that can possibly be optimized in the future): - * - * For each held `task`, we keep a Boolean `changed(task)`, which indicates whether new results - * for the `task` were produced. We initialize the Booleans to be true. Computation is - * terminated when all Booleans are false, that is, when no more changes in the result table - * can be observed. - * - * If we do detect a change, we determine all tasks for which changed results exist and - * recompute their results. We compare the results with those produced previously (stored in - * `resultsProducedByTask`). If any new results were created, `changed` is set to true for the - * result's table entry and `resultsProductByTask` is updated. - */ - def completeHeldTasks(): Unit = + /** Add results produced by held task until no more change can be observed. + * + * We use the following simple algorithm (that can possibly be optimized in the future): + * + * For each held `task`, we keep a Boolean `changed(task)`, which indicates whether new results + * for the `task` were produced. We initialize the Booleans to be true. Computation is terminated + * when all Booleans are false, that is, when no more changes in the result table can be + * observed. + * + * If we do detect a change, we determine all tasks for which changed results exist and recompute + * their results. We compare the results with those produced previously (stored in + * `resultsProducedByTask`). If any new results were created, `changed` is set to true for the + * result's table entry and `resultsProductByTask` is updated. + */ + def completeHeldTasks(): Unit = - deduplicateResultTable() - val toProcess = - heldTasks.distinct.sortBy(x => - (x.fingerprint.sink.id, x.fingerprint.callSiteStack.map(_.id).toString, x.callDepth) - ) - var resultsProducedByTask: Map[ReachableByTask, Set[(TaskFingerprint, TableEntry)]] = Map() + deduplicateResultTable() + val toProcess = + heldTasks.distinct.sortBy(x => + (x.fingerprint.sink.id, x.fingerprint.callSiteStack.map(_.id).toString, x.callDepth) + ) + var resultsProducedByTask: Map[ReachableByTask, Set[(TaskFingerprint, TableEntry)]] = Map() - def allChanged = toProcess.map { task => task.fingerprint -> true }.toMap - def noneChanged = toProcess.map { t => t.fingerprint -> false }.toMap + def allChanged = toProcess.map { task => task.fingerprint -> true }.toMap + def noneChanged = toProcess.map { t => t.fingerprint -> false }.toMap - var changed: Map[TaskFingerprint, Boolean] = allChanged + var changed: Map[TaskFingerprint, Boolean] = allChanged - while changed.values.toList.contains(true) do - val taskResultsPairs = toProcess - .filter(t => changed(t.fingerprint)) - .par - .map { t => - val resultsForTask = resultsForHeldTask(t).toSet - val newResults = resultsForTask -- resultsProducedByTask.getOrElse(t, Set()) - (t, resultsForTask, newResults) - } - .filter { case (_, _, newResults) => newResults.nonEmpty } - .seq + while changed.values.toList.contains(true) do + val taskResultsPairs = toProcess + .filter(t => changed(t.fingerprint)) + .par + .map { t => + val resultsForTask = resultsForHeldTask(t).toSet + val newResults = resultsForTask -- resultsProducedByTask.getOrElse(t, Set()) + (t, resultsForTask, newResults) + } + .filter { case (_, _, newResults) => newResults.nonEmpty } + .seq + + changed = noneChanged + taskResultsPairs.foreach { case (t, resultsForTask, newResults) => + addCompletedTasksToMainTable(newResults.toList) + newResults.foreach { case (fingerprint, _) => + changed += fingerprint -> true + } + resultsProducedByTask += (t -> resultsForTask) + } + end while + deduplicateResultTable() + end completeHeldTasks - changed = noneChanged - taskResultsPairs.foreach { case (t, resultsForTask, newResults) => - addCompletedTasksToMainTable(newResults.toList) - newResults.foreach { case (fingerprint, _) => - changed += fingerprint -> true + /** In essence, completing a held task simply means appending the path stored in the held task to + * all results that are available for the held task in the table. In practice, we create one + * result for each task of the parent task's `taskStack`, so that we do not only get a new result + * for the sink, but one for each of the parent nodes on the way. + */ + private def resultsForHeldTask(heldTask: ReachableByTask): List[(TaskFingerprint, TableEntry)] = + // Create a flat list of results by computing results for each + // table entry and appending them. + resultTable.get(heldTask.fingerprint) match + case Some(results) => + results + .flatMap { r => + createResultsForHeldTaskAndTableResult(heldTask, r) } - resultsProducedByTask += (t -> resultsForTask) - } - end while - deduplicateResultTable() - end completeHeldTasks + case None => List() - /** In essence, completing a held task simply means appending the path stored in the held task - * to all results that are available for the held task in the table. In practice, we create one - * result for each task of the parent task's `taskStack`, so that we do not only get a new - * result for the sink, but one for each of the parent nodes on the way. - */ - private def resultsForHeldTask(heldTask: ReachableByTask): List[(TaskFingerprint, TableEntry)] = - // Create a flat list of results by computing results for each - // table entry and appending them. - resultTable.get(heldTask.fingerprint) match - case Some(results) => - results - .flatMap { r => - createResultsForHeldTaskAndTableResult(heldTask, r) - } - case None => List() + /** This method creates a list of results from a held task and a table entry by appending paths of + * the held task to the path stored in the held task (`initialPath`) up to each of its parent + * tasks. + * + * A possible optimization here is to store computed slices in a lazily populated table and + * attempt to look them up. + */ + private def createResultsForHeldTaskAndTableResult( + heldTask: ReachableByTask, + result: TableEntry + ): List[(TaskFingerprint, TableEntry)] = + val parentTasks = heldTask.taskStack.dropRight(1) + val initialPath = heldTask.initialPath + parentTasks + .map { parentTask => + val stopIndex = initialPath + .map(x => (x.node, x.callSiteStack)) + .indexOf((parentTask.sink, parentTask.callSiteStack)) + 1 + val initialPathOnlyUpToSink = initialPath.slice(0, stopIndex) + val newPath = result.path ++ initialPathOnlyUpToSink + (parentTask, TableEntry(newPath)) + } + .filter { case (_, tableEntry) => containsCycle(tableEntry) } - /** This method creates a list of results from a held task and a table entry by appending paths - * of the held task to the path stored in the held task (`initialPath`) up to each of its - * parent tasks. - * - * A possible optimization here is to store computed slices in a lazily populated table and - * attempt to look them up. - */ - private def createResultsForHeldTaskAndTableResult( - heldTask: ReachableByTask, - result: TableEntry - ): List[(TaskFingerprint, TableEntry)] = - val parentTasks = heldTask.taskStack.dropRight(1) - val initialPath = heldTask.initialPath - parentTasks - .map { parentTask => - val stopIndex = initialPath - .map(x => (x.node, x.callSiteStack)) - .indexOf((parentTask.sink, parentTask.callSiteStack)) + 1 - val initialPathOnlyUpToSink = initialPath.slice(0, stopIndex) - val newPath = result.path ++ initialPathOnlyUpToSink - (parentTask, TableEntry(newPath)) - } - .filter { case (_, tableEntry) => containsCycle(tableEntry) } + private def containsCycle(tableEntry: TableEntry): Boolean = + val pathSeq = + tableEntry.path.map(x => (x.node, x.callSiteStack, x.isOutputArg, x.outEdgeLabel)) + pathSeq.distinct.size == pathSeq.size - private def containsCycle(tableEntry: TableEntry): Boolean = - val pathSeq = - tableEntry.path.map(x => (x.node, x.callSiteStack, x.isOutputArg, x.outEdgeLabel)) - pathSeq.distinct.size == pathSeq.size + private def addCompletedTasksToMainTable(results: List[(TaskFingerprint, TableEntry)]): Unit = + results.groupBy(_._1).foreach { case (fingerprint, resultList) => + val entries = resultList.map(_._2) + val old = resultTable.getOrElse(fingerprint, Vector()).toList + resultTable.put(fingerprint, deduplicateTableEntries(old ++ entries)) + } - private def addCompletedTasksToMainTable(results: List[(TaskFingerprint, TableEntry)]): Unit = - results.groupBy(_._1).foreach { case (fingerprint, resultList) => - val entries = resultList.map(_._2) - val old = resultTable.getOrElse(fingerprint, Vector()).toList - resultTable.put(fingerprint, deduplicateTableEntries(old ++ entries)) - } + private def deduplicateResultTable(): Unit = + resultTable.keys.foreach { key => + val results = resultTable(key) + resultTable.put(key, deduplicateTableEntries(results)) + } - private def deduplicateResultTable(): Unit = - resultTable.keys.foreach { key => - val results = resultTable(key) - resultTable.put(key, deduplicateTableEntries(results)) - } - - /** This method deduplicates the list of entries stored in a table cell. - * - * We treat entries as the same if their start and end point are the same. Points are given by - * nodes in the graph, the `callSiteStack` and the `isOutputArg` flag. - * - * For a group of flows that we treat as the same, we select the flow with the maximum length. - * If there are multiple flows with maximum length, then we compute a string representation of - * the flows - taking into account all fields - * - and select the flow with maximum length that is smallest in terms of this string - * representation. - */ - private def deduplicateTableEntries(list: List[TableEntry]): List[TableEntry] = - list - .groupBy { result => - val head = - result.path.headOption.map(x => (x.node, x.callSiteStack, x.isOutputArg)).get - val last = - result.path.lastOption.map(x => (x.node, x.callSiteStack, x.isOutputArg)).get - (head, last) - } - .map { case (_, list) => - val lenIdPathPairs = list.map(x => (x.path.length, x)) - val withMaxLength = (lenIdPathPairs.sortBy(_._1).reverse match - case Nil => Nil - case h :: t => h :: t.takeWhile(y => y._1 == h._1) - ).map(_._2) + /** This method deduplicates the list of entries stored in a table cell. + * + * We treat entries as the same if their start and end point are the same. Points are given by + * nodes in the graph, the `callSiteStack` and the `isOutputArg` flag. + * + * For a group of flows that we treat as the same, we select the flow with the maximum length. If + * there are multiple flows with maximum length, then we compute a string representation of the + * flows - taking into account all fields + * - and select the flow with maximum length that is smallest in terms of this string + * representation. + */ + private def deduplicateTableEntries(list: List[TableEntry]): List[TableEntry] = + list + .groupBy { result => + val head = + result.path.headOption.map(x => (x.node, x.callSiteStack, x.isOutputArg)).get + val last = + result.path.lastOption.map(x => (x.node, x.callSiteStack, x.isOutputArg)).get + (head, last) + } + .map { case (_, list) => + val lenIdPathPairs = list.map(x => (x.path.length, x)) + val withMaxLength = (lenIdPathPairs.sortBy(_._1).reverse match + case Nil => Nil + case h :: t => h :: t.takeWhile(y => y._1 == h._1) + ).map(_._2) - if withMaxLength.length == 1 then - withMaxLength.head - else - withMaxLength.minBy { x => - x.path - .map(x => - ( - x.node.id, - x.callSiteStack.map(_.id), - x.visible, - x.isOutputArg, - x.outEdgeLabel - ).toString - ) - .mkString("-") - } - } - .toList + if withMaxLength.length == 1 then + withMaxLength.head + else + withMaxLength.minBy { x => + x.path + .map(x => + ( + x.node.id, + x.callSiteStack.map(_.id), + x.visible, + x.isOutputArg, + x.outEdgeLabel + ).toString + ) + .mkString("-") + } + } + .toList end HeldTaskCompletion diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/SourcesToStartingPoints.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/SourcesToStartingPoints.scala index 2b87b7e8..1fda6782 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/SourcesToStartingPoints.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/SourcesToStartingPoints.scala @@ -17,38 +17,38 @@ case class StartingPointWithSource(startingPoint: CfgNode, source: StoredNode) object SourcesToStartingPoints: - private val log = LoggerFactory.getLogger(SourcesToStartingPoints.getClass) - - def sourceTravsToStartingPoints[NodeType](sourceTravs: IterableOnce[NodeType]*) - : List[StartingPointWithSource] = - val fjp = new ForkJoinPool(Runtime.getRuntime.availableProcessors() / 2) - try - fjp.invoke(new SourceTravsToStartingPointsTask(sourceTravs*)).distinct - catch - case e: RejectedExecutionException => - log.error("Unable to execute 'SourceTravsToStartingPoints` task", e); List() - finally - fjp.shutdown() + private val log = LoggerFactory.getLogger(SourcesToStartingPoints.getClass) + + def sourceTravsToStartingPoints[NodeType](sourceTravs: IterableOnce[NodeType]*) + : List[StartingPointWithSource] = + val fjp = new ForkJoinPool(Runtime.getRuntime.availableProcessors() / 2) + try + fjp.invoke(new SourceTravsToStartingPointsTask(sourceTravs*)).distinct + catch + case e: RejectedExecutionException => + log.error("Unable to execute 'SourceTravsToStartingPoints` task", e); List() + finally + fjp.shutdown() class SourceTravsToStartingPointsTask[NodeType](sourceTravs: IterableOnce[NodeType]*) extends RecursiveTask[List[StartingPointWithSource]]: - private val log = LoggerFactory.getLogger(this.getClass) - - override def compute(): List[StartingPointWithSource] = - val sources: List[StoredNode] = sourceTravs - .flatMap(_.iterator.toList) - .collect { case n: StoredNode => n } - .dedup - .toList - .sortBy(_.id) - val tasks = sources.map(src => (src, new SourceToStartingPoints(src).fork())) - tasks.flatMap { case (src, t: ForkJoinTask[List[CfgNode]]) => - Try(t.get()) match - case Failure(e) => - log.error("Unable to complete 'SourceToStartingPoints' task", e); List() - case Success(sources) => sources.map(s => StartingPointWithSource(s, src)) - } + private val log = LoggerFactory.getLogger(this.getClass) + + override def compute(): List[StartingPointWithSource] = + val sources: List[StoredNode] = sourceTravs + .flatMap(_.iterator.toList) + .collect { case n: StoredNode => n } + .dedup + .toList + .sortBy(_.id) + val tasks = sources.map(src => (src, new SourceToStartingPoints(src).fork())) + tasks.flatMap { case (src, t: ForkJoinTask[List[CfgNode]]) => + Try(t.get()) match + case Failure(e) => + log.error("Unable to complete 'SourceToStartingPoints' task", e); List() + case Success(sources) => sources.map(s => StartingPointWithSource(s, src)) + } end SourceTravsToStartingPointsTask /** The code below deals with member variables, and specifically with the situation where literals @@ -59,168 +59,168 @@ end SourceTravsToStartingPointsTask */ class SourceToStartingPoints(src: StoredNode) extends RecursiveTask[List[CfgNode]]: - private val cpg = Cpg(src.graph()) - - override def compute(): List[CfgNode] = sourceToStartingPoints(src) - - private def sourceToStartingPoints(src: StoredNode): List[CfgNode] = - src match - case methodReturn: MethodReturn => - methodReturn.method.callIn.l - case lit: Literal => - List(lit) ++ usages( - targetsToClassIdentifierPair(literalToInitializedMembers(lit)) - ) ++ globalFromLiteral(lit) - case member: Member => - usages(targetsToClassIdentifierPair(List(member))) - case x: Declaration => - List(x).collectAll[CfgNode].toList - case x: Identifier => - (withFieldAndIndexAccesses( - List(x).collectAll[CfgNode].toList ++ x.refsTo.collectAll[Local].flatMap( - sourceToStartingPoints - ) - ) ++ x.refsTo.capturedByMethodRef.referencedMethod.flatMap(m => - usagesForName(x.name, m) - )).flatMap { - case x: Call => sourceToStartingPoints(x) - case x => List(x) - } - case x: Call => - (x._receiverIn.l :+ x).collect { case y: CfgNode => y } - case x => List(x).collect { case y: CfgNode => y } - - private def withFieldAndIndexAccesses(nodes: List[CfgNode]): List[CfgNode] = - nodes.flatMap { - case identifier: Identifier => - List(identifier) ++ fieldAndIndexAccesses(identifier) - case x => List(x) - } - - private def fieldAndIndexAccesses(identifier: Identifier): List[CfgNode] = - identifier.method._identifierViaContainsOut - .nameExact(identifier.name) - .inCall - .collect { case c if isFieldAccess(c.name) => c } - .l - - private def usages(pairs: List[(TypeDecl, AstNode)]): List[CfgNode] = - pairs.flatMap { case (typeDecl, astNode) => - val nonConstructorMethods = methodsRecursively(typeDecl).iterator - .whereNot(_.nameExact( - Defines.StaticInitMethodName, - Defines.ConstructorMethodName, - "__init__" - )) - .l - - val usagesInSameClass = - nonConstructorMethods.flatMap { m => firstUsagesOf(astNode, m, typeDecl) } - - val usagesInOtherClasses = cpg.method.flatMap { m => - m.fieldAccess - .where(_.argument(1).isIdentifier.typeFullNameExact(typeDecl.fullName)) - .where { x => - astNode match - case identifier: Identifier => - x.argument(2).isFieldIdentifier.canonicalNameExact(identifier.name) - case fieldIdentifier: FieldIdentifier => - x.argument(2).isFieldIdentifier.canonicalNameExact( - fieldIdentifier.canonicalName - ) - case _ => Iterator.empty - } - .takeWhile(notLeftHandOfAssignment) - .headOption - }.l - usagesInSameClass ++ usagesInOtherClasses - } - - /** For given method, determine the first usage of the given expression. - */ - private def firstUsagesOf(astNode: AstNode, m: Method, typeDecl: TypeDecl): List[Expression] = - astNode match - case member: Member => - usagesForName(member.name, m) - case identifier: Identifier => - usagesForName(identifier.name, m) - case fieldIdentifier: FieldIdentifier => - val fieldIdentifiers = - m.ast.isFieldIdentifier.sortBy(x => (x.lineNumber, x.columnNumber)).l - fieldIdentifiers - .canonicalNameExact(fieldIdentifier.canonicalName) - .inFieldAccess - // TODO `isIdentifier` seems to limit us here - .where(_.argument(1).isIdentifier.or( - _.nameExact("this", "self"), - _.typeFullNameExact(typeDecl.fullName) - )) - .takeWhile(notLeftHandOfAssignment) - .l - case _ => List() - - private def usagesForName(name: String, m: Method): List[Expression] = - val identifiers = m.ast.isIdentifier.sortBy(x => (x.lineNumber, x.columnNumber)).l - val identifierUsages = identifiers.nameExact(name).takeWhile(notLeftHandOfAssignment).l - val fieldIdentifiers = m.ast.isFieldIdentifier.sortBy(x => (x.lineNumber, x.columnNumber)).l - val thisRefs = Seq("this", "self") ++ m.typeDecl.name.headOption.toList - val fieldAccessUsages = fieldIdentifiers.isFieldIdentifier - .canonicalNameExact(name) - .inFieldAccess - .where(_.argument(1).codeExact(thisRefs*)) - .takeWhile(notLeftHandOfAssignment) - .l - (identifierUsages ++ fieldAccessUsages).headOption.toList - - /** For a literal, determine if it is used in the initialization of any member variables. Return - * list of initialized members. An initialized member is either an identifier or a - * field-identifier. - */ - private def literalToInitializedMembers(lit: Literal): List[Expression] = - lit.inAssignment - .or( - _.method.nameExact( + private val cpg = Cpg(src.graph()) + + override def compute(): List[CfgNode] = sourceToStartingPoints(src) + + private def sourceToStartingPoints(src: StoredNode): List[CfgNode] = + src match + case methodReturn: MethodReturn => + methodReturn.method.callIn.l + case lit: Literal => + List(lit) ++ usages( + targetsToClassIdentifierPair(literalToInitializedMembers(lit)) + ) ++ globalFromLiteral(lit) + case member: Member => + usages(targetsToClassIdentifierPair(List(member))) + case x: Declaration => + List(x).collectAll[CfgNode].toList + case x: Identifier => + (withFieldAndIndexAccesses( + List(x).collectAll[CfgNode].toList ++ x.refsTo.collectAll[Local].flatMap( + sourceToStartingPoints + ) + ) ++ x.refsTo.capturedByMethodRef.referencedMethod.flatMap(m => + usagesForName(x.name, m) + )).flatMap { + case x: Call => sourceToStartingPoints(x) + case x => List(x) + } + case x: Call => + (x._receiverIn.l :+ x).collect { case y: CfgNode => y } + case x => List(x).collect { case y: CfgNode => y } + + private def withFieldAndIndexAccesses(nodes: List[CfgNode]): List[CfgNode] = + nodes.flatMap { + case identifier: Identifier => + List(identifier) ++ fieldAndIndexAccesses(identifier) + case x => List(x) + } + + private def fieldAndIndexAccesses(identifier: Identifier): List[CfgNode] = + identifier.method._identifierViaContainsOut + .nameExact(identifier.name) + .inCall + .collect { case c if isFieldAccess(c.name) => c } + .l + + private def usages(pairs: List[(TypeDecl, AstNode)]): List[CfgNode] = + pairs.flatMap { case (typeDecl, astNode) => + val nonConstructorMethods = methodsRecursively(typeDecl).iterator + .whereNot(_.nameExact( Defines.StaticInitMethodName, Defines.ConstructorMethodName, "__init__" - ), - // in language such as Python, where assignments for members can be directly under a type decl - _.method.typeDecl - ) - .target - .flatMap { - case identifier: Identifier - // If these are the same, then the parent method is the module-level type - if Option( - identifier.method.fullName - ) == identifier.method.typeDecl.fullName.headOption || - // If a member shares the name of the identifier then we consider this as a member - lit.method.typeDecl.member.name.toSet.contains(identifier.name) => - List(identifier) - case call: Call if call.name == Operators.fieldAccess => - call.ast.isFieldIdentifier.l - case _ => List[Expression]() - } - .l - - private def methodsRecursively(typeDecl: TypeDecl): List[Method] = - def methods(x: AstNode): List[Method] = - x match - case m: Method => m :: m.astMinusRoot.isMethod.flatMap(methods).l - case _ => List() - typeDecl.method.flatMap(methods).l - - private def isTargetInAssignment(identifier: Identifier): List[Identifier] = - identifier.start.argumentIndex(1).where(_.inAssignment).l - - private def notLeftHandOfAssignment(x: Expression): Boolean = - !(x.argumentIndex == 1 && x.inCall.exists(y => allAssignmentTypes.contains(y.name))) - - private def targetsToClassIdentifierPair(targets: List[AstNode]): List[(TypeDecl, AstNode)] = - targets.flatMap { - case expr: Expression => - expr.method.typeDecl.map { typeDecl => (typeDecl, expr) } - case member: Member => - member.typeDecl.map { typeDecl => (typeDecl, member) } - } + )) + .l + + val usagesInSameClass = + nonConstructorMethods.flatMap { m => firstUsagesOf(astNode, m, typeDecl) } + + val usagesInOtherClasses = cpg.method.flatMap { m => + m.fieldAccess + .where(_.argument(1).isIdentifier.typeFullNameExact(typeDecl.fullName)) + .where { x => + astNode match + case identifier: Identifier => + x.argument(2).isFieldIdentifier.canonicalNameExact(identifier.name) + case fieldIdentifier: FieldIdentifier => + x.argument(2).isFieldIdentifier.canonicalNameExact( + fieldIdentifier.canonicalName + ) + case _ => Iterator.empty + } + .takeWhile(notLeftHandOfAssignment) + .headOption + }.l + usagesInSameClass ++ usagesInOtherClasses + } + + /** For given method, determine the first usage of the given expression. + */ + private def firstUsagesOf(astNode: AstNode, m: Method, typeDecl: TypeDecl): List[Expression] = + astNode match + case member: Member => + usagesForName(member.name, m) + case identifier: Identifier => + usagesForName(identifier.name, m) + case fieldIdentifier: FieldIdentifier => + val fieldIdentifiers = + m.ast.isFieldIdentifier.sortBy(x => (x.lineNumber, x.columnNumber)).l + fieldIdentifiers + .canonicalNameExact(fieldIdentifier.canonicalName) + .inFieldAccess + // TODO `isIdentifier` seems to limit us here + .where(_.argument(1).isIdentifier.or( + _.nameExact("this", "self"), + _.typeFullNameExact(typeDecl.fullName) + )) + .takeWhile(notLeftHandOfAssignment) + .l + case _ => List() + + private def usagesForName(name: String, m: Method): List[Expression] = + val identifiers = m.ast.isIdentifier.sortBy(x => (x.lineNumber, x.columnNumber)).l + val identifierUsages = identifiers.nameExact(name).takeWhile(notLeftHandOfAssignment).l + val fieldIdentifiers = m.ast.isFieldIdentifier.sortBy(x => (x.lineNumber, x.columnNumber)).l + val thisRefs = Seq("this", "self") ++ m.typeDecl.name.headOption.toList + val fieldAccessUsages = fieldIdentifiers.isFieldIdentifier + .canonicalNameExact(name) + .inFieldAccess + .where(_.argument(1).codeExact(thisRefs*)) + .takeWhile(notLeftHandOfAssignment) + .l + (identifierUsages ++ fieldAccessUsages).headOption.toList + + /** For a literal, determine if it is used in the initialization of any member variables. Return + * list of initialized members. An initialized member is either an identifier or a + * field-identifier. + */ + private def literalToInitializedMembers(lit: Literal): List[Expression] = + lit.inAssignment + .or( + _.method.nameExact( + Defines.StaticInitMethodName, + Defines.ConstructorMethodName, + "__init__" + ), + // in language such as Python, where assignments for members can be directly under a type decl + _.method.typeDecl + ) + .target + .flatMap { + case identifier: Identifier + // If these are the same, then the parent method is the module-level type + if Option( + identifier.method.fullName + ) == identifier.method.typeDecl.fullName.headOption || + // If a member shares the name of the identifier then we consider this as a member + lit.method.typeDecl.member.name.toSet.contains(identifier.name) => + List(identifier) + case call: Call if call.name == Operators.fieldAccess => + call.ast.isFieldIdentifier.l + case _ => List[Expression]() + } + .l + + private def methodsRecursively(typeDecl: TypeDecl): List[Method] = + def methods(x: AstNode): List[Method] = + x match + case m: Method => m :: m.astMinusRoot.isMethod.flatMap(methods).l + case _ => List() + typeDecl.method.flatMap(methods).l + + private def isTargetInAssignment(identifier: Identifier): List[Identifier] = + identifier.start.argumentIndex(1).where(_.inAssignment).l + + private def notLeftHandOfAssignment(x: Expression): Boolean = + !(x.argumentIndex == 1 && x.inCall.exists(y => allAssignmentTypes.contains(y.name))) + + private def targetsToClassIdentifierPair(targets: List[AstNode]): List[(TypeDecl, AstNode)] = + targets.flatMap { + case expr: Expression => + expr.method.typeDecl.map { typeDecl => (typeDecl, expr) } + case member: Member => + member.typeDecl.map { typeDecl => (typeDecl, member) } + } end SourceToStartingPoints diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/TaskCreator.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/TaskCreator.scala index 7902fe1e..fc483ab0 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/TaskCreator.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/TaskCreator.scala @@ -19,198 +19,198 @@ import org.slf4j.{Logger, LoggerFactory} */ class TaskCreator(context: EngineContext): - private val logger: Logger = LoggerFactory.getLogger(this.getClass) - - /** For a given list of results and sources, generate new tasks. - */ - def createFromResults(results: Vector[ReachableByResult]): Vector[ReachableByTask] = - val newTasks = tasksForParams(results) ++ tasksForUnresolvedOutArgs(results) - removeTasksWithLoopsAndTooHighCallDepth(newTasks) - - private def removeTasksWithLoopsAndTooHighCallDepth(tasks: Vector[ReachableByTask]) - : Vector[ReachableByTask] = - val tasksWithValidCallDepth = if context.config.maxCallDepth == -1 then - tasks - else - tasks.filter(_.callDepth <= context.config.maxCallDepth) - tasksWithValidCallDepth.filter { t => - t.taskStack.dedup.size == t.taskStack.size - } - - /** Create new tasks from all results that start in a parameter. In essence, we want to traverse - * to corresponding arguments of call sites, but we need to be careful here not to create - * unrealizable paths. We achieve this by holding a call stack in results. - * - * Case 1: we expanded into a callee that we identified on the way, e.g., a method `y = - * transform(x)`, and we have reached the parameter of that method (`transform`). Upon doing - * so, we recorded the call site that we expanded in `result.callSite`. We would now like to - * continue exploring from the corresponding argument at that call site only. - * - * Case 2: walking backward from the sink, we have only expanded into callers so far, that is, - * the call stack is empty. In this case, the next tasks need to explore each call site to the - * method. - */ - private def tasksForParams(results: Vector[ReachableByResult]): Vector[ReachableByTask] = - startsAtParameter(results).flatMap { result => - val param = result.path.head.node.asInstanceOf[MethodParameterIn] - result.callSiteStack match - case callSite :: tail => - // Case 1 - paramToArgs(param).filter(x => x.inCall.exists(c => c == callSite)).map { arg => - ReachableByTask( - result.taskStack :+ TaskFingerprint(arg, tail, result.callDepth - 1), - result.path - ) - } - case _ => - // Case 2 - paramToArgs(param).map { arg => - ReachableByTask( - result.taskStack :+ TaskFingerprint(arg, List(), result.callDepth + 1), - result.path + private val logger: Logger = LoggerFactory.getLogger(this.getClass) + + /** For a given list of results and sources, generate new tasks. + */ + def createFromResults(results: Vector[ReachableByResult]): Vector[ReachableByTask] = + val newTasks = tasksForParams(results) ++ tasksForUnresolvedOutArgs(results) + removeTasksWithLoopsAndTooHighCallDepth(newTasks) + + private def removeTasksWithLoopsAndTooHighCallDepth(tasks: Vector[ReachableByTask]) + : Vector[ReachableByTask] = + val tasksWithValidCallDepth = if context.config.maxCallDepth == -1 then + tasks + else + tasks.filter(_.callDepth <= context.config.maxCallDepth) + tasksWithValidCallDepth.filter { t => + t.taskStack.dedup.size == t.taskStack.size + } + + /** Create new tasks from all results that start in a parameter. In essence, we want to traverse + * to corresponding arguments of call sites, but we need to be careful here not to create + * unrealizable paths. We achieve this by holding a call stack in results. + * + * Case 1: we expanded into a callee that we identified on the way, e.g., a method `y = + * transform(x)`, and we have reached the parameter of that method (`transform`). Upon doing so, + * we recorded the call site that we expanded in `result.callSite`. We would now like to continue + * exploring from the corresponding argument at that call site only. + * + * Case 2: walking backward from the sink, we have only expanded into callers so far, that is, + * the call stack is empty. In this case, the next tasks need to explore each call site to the + * method. + */ + private def tasksForParams(results: Vector[ReachableByResult]): Vector[ReachableByTask] = + startsAtParameter(results).flatMap { result => + val param = result.path.head.node.asInstanceOf[MethodParameterIn] + result.callSiteStack match + case callSite :: tail => + // Case 1 + paramToArgs(param).filter(x => x.inCall.exists(c => c == callSite)).map { arg => + ReachableByTask( + result.taskStack :+ TaskFingerprint(arg, tail, result.callDepth - 1), + result.path + ) + } + case _ => + // Case 2 + paramToArgs(param).map { arg => + ReachableByTask( + result.taskStack :+ TaskFingerprint(arg, List(), result.callDepth + 1), + result.path + ) + } + } + + /** Returns only those results that start at a parameter node. + */ + private def startsAtParameter(results: Vector[ReachableByResult]) = + results.collect { + case r: ReachableByResult if r.path.head.node.isInstanceOf[MethodParameterIn] => r + } + + /** For a given parameter of a method, determine all corresponding arguments at all call sites to + * the method. + */ + private def paramToArgs(param: MethodParameterIn): List[Expression] = + val args = paramToArgsOfCallers(param) ++ paramToMethodRefCallReceivers(param) + if args.size > context.config.maxArgsToAllow then + logger.warn(s"Too many arguments for parameter: ${args.size}. Not expanding") + logger.warn("Method name: " + param.method.fullName) + List() + else + args + + private def paramToArgsOfCallers(param: MethodParameterIn): List[Expression] = + NoResolve + .getMethodCallsites(param.method) + .collectAll[Call] + .argument(param.index) + .l + + /** Expand to receiver objects of calls that reference the method of the parameter, e.g., if + * `param` is a parameter of `m`, return `foo` in `foo.bar(m)` TODO: I'm not sure whether + * `methodRef.methodFullNameExact(...)` uses an index. If not, then caching these lookups or + * keeping a map of all method names to their references may make sense. + */ + + private def paramToMethodRefCallReceivers(param: MethodParameterIn): List[Expression] = + new Cpg(param.graph()).methodRef.methodFullNameExact(param.method.fullName).inCall.argument( + 0 + ).l + + /** Create new tasks from all results that end in an output argument, including return arguments. + * In this case, we want to traverse to corresponding method output parameters and method return + * nodes respectively. + */ + private def tasksForUnresolvedOutArgs(results: Vector[ReachableByResult]) + : Vector[ReachableByTask] = + + val outArgsAndCalls = results + .map(x => (x, x.outputArgument, x.path, x.callDepth)) + .distinct + + val forCalls = outArgsAndCalls.flatMap { case (result, outArg, path, callDepth) => + outArg.toList.flatMap { + case call: Call => + val methodReturns = call.toList + .flatMap(x => NoResolve.getCalledMethods(x).methodReturn.map(y => (x, y))) + .iterator + + methodReturns.flatMap { case (call, methodReturn) => + val method = methodReturn.method + val returnStatements = + methodReturn._reachingDefIn.toList.collect { case r: Return => r } + if method.isExternal || method.start.isStub.nonEmpty then + val newPath = path + (call.receiver.l ++ call.argument.l).map { arg => + val taskStack = result.taskStack :+ TaskFingerprint( + arg, + result.callSiteStack, + callDepth ) - } - } - - /** Returns only those results that start at a parameter node. - */ - private def startsAtParameter(results: Vector[ReachableByResult]) = - results.collect { - case r: ReachableByResult if r.path.head.node.isInstanceOf[MethodParameterIn] => r - } - - /** For a given parameter of a method, determine all corresponding arguments at all call sites - * to the method. - */ - private def paramToArgs(param: MethodParameterIn): List[Expression] = - val args = paramToArgsOfCallers(param) ++ paramToMethodRefCallReceivers(param) - if args.size > context.config.maxArgsToAllow then - logger.warn(s"Too many arguments for parameter: ${args.size}. Not expanding") - logger.warn("Method name: " + param.method.fullName) - List() - else - args - - private def paramToArgsOfCallers(param: MethodParameterIn): List[Expression] = - NoResolve - .getMethodCallsites(param.method) - .collectAll[Call] - .argument(param.index) - .l - - /** Expand to receiver objects of calls that reference the method of the parameter, e.g., if - * `param` is a parameter of `m`, return `foo` in `foo.bar(m)` TODO: I'm not sure whether - * `methodRef.methodFullNameExact(...)` uses an index. If not, then caching these lookups or - * keeping a map of all method names to their references may make sense. - */ - - private def paramToMethodRefCallReceivers(param: MethodParameterIn): List[Expression] = - new Cpg(param.graph()).methodRef.methodFullNameExact(param.method.fullName).inCall.argument( - 0 - ).l - - /** Create new tasks from all results that end in an output argument, including return - * arguments. In this case, we want to traverse to corresponding method output parameters and - * method return nodes respectively. - */ - private def tasksForUnresolvedOutArgs(results: Vector[ReachableByResult]) - : Vector[ReachableByTask] = - - val outArgsAndCalls = results - .map(x => (x, x.outputArgument, x.path, x.callDepth)) - .distinct - - val forCalls = outArgsAndCalls.flatMap { case (result, outArg, path, callDepth) => - outArg.toList.flatMap { - case call: Call => - val methodReturns = call.toList - .flatMap(x => NoResolve.getCalledMethods(x).methodReturn.map(y => (x, y))) - .iterator - - methodReturns.flatMap { case (call, methodReturn) => - val method = methodReturn.method - val returnStatements = - methodReturn._reachingDefIn.toList.collect { case r: Return => r } - if method.isExternal || method.start.isStub.nonEmpty then - val newPath = path - (call.receiver.l ++ call.argument.l).map { arg => - val taskStack = result.taskStack :+ TaskFingerprint( - arg, - result.callSiteStack, - callDepth - ) - ReachableByTask(taskStack, newPath) - } - else - returnStatements.map { returnStatement => - val newPath = - Vector(PathElement(methodReturn, result.callSiteStack)) ++ path - val taskStack = - result.taskStack :+ TaskFingerprint( - returnStatement, - call :: result.callSiteStack, - callDepth + 1 - ) - ReachableByTask(taskStack, newPath) - } - end if - } - case _ => Vector.empty - } - } - - val forArgs = outArgsAndCalls.flatMap { case (result, args, path, callDepth) => - args.toList.flatMap { - case arg: Expression => - val outParams = if result.callSiteStack.nonEmpty then - List[MethodParameterOut]() + ReachableByTask(taskStack, newPath) + } else - argToOutputParams(arg).l - outParams - .filterNot(_.method.isExternal) - .map { p => - val newStack = - arg.inCall.headOption.map { x => - x :: result.callSiteStack - }.getOrElse(result.callSiteStack) - ReachableByTask( - result.taskStack :+ TaskFingerprint(p, newStack, callDepth + 1), - path + returnStatements.map { returnStatement => + val newPath = + Vector(PathElement(methodReturn, result.callSiteStack)) ++ path + val taskStack = + result.taskStack :+ TaskFingerprint( + returnStatement, + call :: result.callSiteStack, + callDepth + 1 ) - } - case _ => Vector.empty - } + ReachableByTask(taskStack, newPath) + } + end if + } + case _ => Vector.empty } - - val forMethodRefs = outArgsAndCalls.flatMap { case (result, outArg, path, callDepth) => - outArg.toList.flatMap { - case methodRef: MethodRef => - val methodReturns = methodRef._refOut.collectAll[Method].methodReturn - methodReturns.flatMap { methodReturn => - val returnStatements = - methodReturn._reachingDefIn.toList.collect { case r: Return => r } - returnStatements.map { returnStatement => - val newPath = - Vector(PathElement(methodReturn, result.callSiteStack)) ++ path - val taskStack = - result.taskStack :+ TaskFingerprint( - returnStatement, - result.callSiteStack, - callDepth + 1 - ) - ReachableByTask(taskStack, newPath) - } + } + + val forArgs = outArgsAndCalls.flatMap { case (result, args, path, callDepth) => + args.toList.flatMap { + case arg: Expression => + val outParams = if result.callSiteStack.nonEmpty then + List[MethodParameterOut]() + else + argToOutputParams(arg).l + outParams + .filterNot(_.method.isExternal) + .map { p => + val newStack = + arg.inCall.headOption.map { x => + x :: result.callSiteStack + }.getOrElse(result.callSiteStack) + ReachableByTask( + result.taskStack :+ TaskFingerprint(p, newStack, callDepth + 1), + path + ) } - case _ => Vector.empty - } + case _ => Vector.empty + } + } + + val forMethodRefs = outArgsAndCalls.flatMap { case (result, outArg, path, callDepth) => + outArg.toList.flatMap { + case methodRef: MethodRef => + val methodReturns = methodRef._refOut.collectAll[Method].methodReturn + methodReturns.flatMap { methodReturn => + val returnStatements = + methodReturn._reachingDefIn.toList.collect { case r: Return => r } + returnStatements.map { returnStatement => + val newPath = + Vector(PathElement(methodReturn, result.callSiteStack)) ++ path + val taskStack = + result.taskStack :+ TaskFingerprint( + returnStatement, + result.callSiteStack, + callDepth + 1 + ) + ReachableByTask(taskStack, newPath) + } + } + case _ => Vector.empty } - restrictSize(forCalls) ++ restrictSize(forArgs) ++ restrictSize(forMethodRefs) - end tasksForUnresolvedOutArgs - - private def restrictSize(l: Vector[ReachableByTask]): Vector[ReachableByTask] = - if l.size <= context.config.maxOutputArgsExpansion then - l - else - logger.warn("Too many new tasks in expansion of unresolved output arguments") - Vector() + } + restrictSize(forCalls) ++ restrictSize(forArgs) ++ restrictSize(forMethodRefs) + end tasksForUnresolvedOutArgs + + private def restrictSize(l: Vector[ReachableByTask]): Vector[ReachableByTask] = + if l.size <= context.config.maxOutputArgsExpansion then + l + else + logger.warn("Too many new tasks in expansion of unresolved output arguments") + Vector() end TaskCreator diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/TaskSolver.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/TaskSolver.scala index f3b366c8..3c243bae 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/TaskSolver.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/TaskSolver.scala @@ -27,213 +27,212 @@ import scala.collection.mutable class TaskSolver(task: ReachableByTask, context: EngineContext, sources: Set[CfgNode]) extends Callable[TaskSummary]: - import Engine.* - - /** Entry point of callable. First checks if the maximum call depth has been exceeded, in which - * case an empty result list is returned. Otherwise, the task is solved and its results are - * returned. + import Engine.* + + /** Entry point of callable. First checks if the maximum call depth has been exceeded, in which + * case an empty result list is returned. Otherwise, the task is solved and its results are + * returned. + */ + override def call(): TaskSummary = + implicit val sem: Semantics = context.semantics + val path = Vector(PathElement(task.sink, task.callSiteStack)) + val table: mutable.Map[TaskFingerprint, Vector[ReachableByResult]] = mutable.Map() + results(task.sink, path, table, task.callSiteStack) + // TODO why do we update the call depth here? + val finalResults = table.get(task.fingerprint).get.map { r => + r.copy( + taskStack = + r.taskStack.dropRight(1) :+ r.fingerprint.copy(callDepth = task.callDepth), + path = r.path ++ task.initialPath + ) + } + val (partial, complete) = finalResults.partition(_.partial) + val newTasks = new TaskCreator(context).createFromResults(partial) + TaskSummary(complete.flatMap(r => resultToTableEntries(r)), newTasks) + + private def resultToTableEntries(r: ReachableByResult): List[(TaskFingerprint, TableEntry)] = + r.taskStack.indices.map { i => + val parentTask = r.taskStack(i) + val pathToSink = r.path.slice(0, r.path.map(_.node).indexOf(parentTask.sink)) + val newPath = pathToSink :+ PathElement(parentTask.sink, parentTask.callSiteStack) + (parentTask, TableEntry(path = newPath)) + }.toList + + /** Recursively expand the DDG backwards and return a list of all results, given by at least a + * source node in `sourceSymbols` and the path between the source symbol and the sink. + * + * This method stays within the method (intra-procedural analysis) and terminates at method + * parameters and at output arguments. + * + * @param path + * This is a path from a node to the sink. The first node of the path is expanded by this + * method + * + * @param table + * The result table is a cache of known results that we can re-use + * + * @param callSiteStack + * This stack holds all call sites we expanded to arrive at the generation of the current task + */ + private def results[NodeType <: CfgNode]( + sink: CfgNode, + path: Vector[PathElement], + table: mutable.Map[TaskFingerprint, Vector[ReachableByResult]], + callSiteStack: List[Call] + )(implicit semantics: Semantics): Vector[ReachableByResult] = + + val curNode = path.head.node + + /** For each parent of the current node, determined via `expandIn`, check if results are + * available in the result table. If not, determine results recursively. */ - override def call(): TaskSummary = - implicit val sem: Semantics = context.semantics - val path = Vector(PathElement(task.sink, task.callSiteStack)) - val table: mutable.Map[TaskFingerprint, Vector[ReachableByResult]] = mutable.Map() - results(task.sink, path, table, task.callSiteStack) - // TODO why do we update the call depth here? - val finalResults = table.get(task.fingerprint).get.map { r => - r.copy( - taskStack = - r.taskStack.dropRight(1) :+ r.fingerprint.copy(callDepth = task.callDepth), - path = r.path ++ task.initialPath - ) - } - val (partial, complete) = finalResults.partition(_.partial) - val newTasks = new TaskCreator(context).createFromResults(partial) - TaskSummary(complete.flatMap(r => resultToTableEntries(r)), newTasks) - - private def resultToTableEntries(r: ReachableByResult): List[(TaskFingerprint, TableEntry)] = - r.taskStack.indices.map { i => - val parentTask = r.taskStack(i) - val pathToSink = r.path.slice(0, r.path.map(_.node).indexOf(parentTask.sink)) - val newPath = pathToSink :+ PathElement(parentTask.sink, parentTask.callSiteStack) - (parentTask, TableEntry(path = newPath)) - }.toList - - /** Recursively expand the DDG backwards and return a list of all results, given by at least a - * source node in `sourceSymbols` and the path between the source symbol and the sink. - * - * This method stays within the method (intra-procedural analysis) and terminates at method - * parameters and at output arguments. - * - * @param path - * This is a path from a node to the sink. The first node of the path is expanded by this - * method - * - * @param table - * The result table is a cache of known results that we can re-use - * - * @param callSiteStack - * This stack holds all call sites we expanded to arrive at the generation of the current - * task + def computeResultsForParents() = + deduplicateWithinTask(expandIn( + curNode.asInstanceOf[CfgNode], + path, + callSiteStack + ).iterator.flatMap { parent => + createResultsFromCacheOrCompute(parent, path) + }.toVector) + + def deduplicateWithinTask(vec: Vector[ReachableByResult]): Vector[ReachableByResult] = + vec + .groupBy { result => + val head = result.path.headOption.map(x => + (x.node, x.callSiteStack, x.isOutputArg) + ).get + val last = result.path.lastOption.map(x => + (x.node, x.callSiteStack, x.isOutputArg) + ).get + (head, last, result.partial, result.callDepth) + } + .map { case (_, list) => + val lenIdPathPairs = list.map(x => (x.path.length, x)).toList + val withMaxLength = (lenIdPathPairs.sortBy(_._1).reverse match + case Nil => Nil + case h :: t => h :: t.takeWhile(y => y._1 == h._1) + ).map(_._2) + + if withMaxLength.length == 1 then + withMaxLength.head + else + withMaxLength.minBy { x => + x.callDepth.toString + " " + + x.taskStack + .map(x => + x.sink.id.toString + ":" + x.callSiteStack.map( + _.id + ).mkString("|") + ) + .toString + " " + x.path + .map(x => + ( + x.node.id, + x.callSiteStack.map(_.id), + x.visible, + x.isOutputArg, + x.outEdgeLabel + ).toString + ) + .mkString("-") + } + end if + } + .toVector + + def createResultsFromCacheOrCompute(elemToPrepend: PathElement, path: Vector[PathElement]) = + val cachedResult = + createFromTable(table, elemToPrepend, task.callSiteStack, path, task.callDepth) + if cachedResult.isDefined then + QueryEngineStatistics.incrementBy(PATH_CACHE_HITS, 1L) + cachedResult.get + else + QueryEngineStatistics.incrementBy(PATH_CACHE_MISSES, 1L) + val newPath = elemToPrepend +: path + results(sink, newPath, table, callSiteStack) + + /** For a given path, determine whether results for the first element (`first`) are stored in + * the table, and if so, for each result, determine the path up to `first` and prepend it to + * `path`, giving us new results via table lookup. */ - private def results[NodeType <: CfgNode]( - sink: CfgNode, - path: Vector[PathElement], + def createFromTable( table: mutable.Map[TaskFingerprint, Vector[ReachableByResult]], - callSiteStack: List[Call] - )(implicit semantics: Semantics): Vector[ReachableByResult] = - - val curNode = path.head.node - - /** For each parent of the current node, determined via `expandIn`, check if results are - * available in the result table. If not, determine results recursively. - */ - def computeResultsForParents() = - deduplicateWithinTask(expandIn( - curNode.asInstanceOf[CfgNode], - path, - callSiteStack - ).iterator.flatMap { parent => - createResultsFromCacheOrCompute(parent, path) - }.toVector) - - def deduplicateWithinTask(vec: Vector[ReachableByResult]): Vector[ReachableByResult] = - vec - .groupBy { result => - val head = result.path.headOption.map(x => - (x.node, x.callSiteStack, x.isOutputArg) - ).get - val last = result.path.lastOption.map(x => - (x.node, x.callSiteStack, x.isOutputArg) - ).get - (head, last, result.partial, result.callDepth) - } - .map { case (_, list) => - val lenIdPathPairs = list.map(x => (x.path.length, x)).toList - val withMaxLength = (lenIdPathPairs.sortBy(_._1).reverse match - case Nil => Nil - case h :: t => h :: t.takeWhile(y => y._1 == h._1) - ).map(_._2) - - if withMaxLength.length == 1 then - withMaxLength.head - else - withMaxLength.minBy { x => - x.callDepth.toString + " " + - x.taskStack - .map(x => - x.sink.id.toString + ":" + x.callSiteStack.map( - _.id - ).mkString("|") - ) - .toString + " " + x.path - .map(x => - ( - x.node.id, - x.callSiteStack.map(_.id), - x.visible, - x.isOutputArg, - x.outEdgeLabel - ).toString - ) - .mkString("-") - } - end if - } - .toVector - - def createResultsFromCacheOrCompute(elemToPrepend: PathElement, path: Vector[PathElement]) = - val cachedResult = - createFromTable(table, elemToPrepend, task.callSiteStack, path, task.callDepth) - if cachedResult.isDefined then - QueryEngineStatistics.incrementBy(PATH_CACHE_HITS, 1L) - cachedResult.get - else - QueryEngineStatistics.incrementBy(PATH_CACHE_MISSES, 1L) - val newPath = elemToPrepend +: path - results(sink, newPath, table, callSiteStack) - - /** For a given path, determine whether results for the first element (`first`) are stored - * in the table, and if so, for each result, determine the path up to `first` and prepend - * it to `path`, giving us new results via table lookup. - */ - def createFromTable( - table: mutable.Map[TaskFingerprint, Vector[ReachableByResult]], - first: PathElement, - callSiteStack: List[Call], - remainder: Vector[PathElement], - callDepth: Int - ): Option[Vector[ReachableByResult]] = - table.get( - TaskFingerprint(first.node.asInstanceOf[CfgNode], callSiteStack, callDepth) - ).map { res => - res.map { r => - val stopIndex = r.path.map(x => (x.node, x.callSiteStack)).indexOf(( - first.node, - first.callSiteStack - )) - val pathToFirstNode = r.path.slice(0, stopIndex) - val completePath = pathToFirstNode ++ (first +: remainder) - r.copy(path = Vector(completePath.head) ++ completePath.tail) - } + first: PathElement, + callSiteStack: List[Call], + remainder: Vector[PathElement], + callDepth: Int + ): Option[Vector[ReachableByResult]] = + table.get( + TaskFingerprint(first.node.asInstanceOf[CfgNode], callSiteStack, callDepth) + ).map { res => + res.map { r => + val stopIndex = r.path.map(x => (x.node, x.callSiteStack)).indexOf(( + first.node, + first.callSiteStack + )) + val pathToFirstNode = r.path.slice(0, stopIndex) + val completePath = pathToFirstNode ++ (first +: remainder) + r.copy(path = Vector(completePath.head) ++ completePath.tail) } + } + + def createPartialResultForOutputArgOrRet() = + Vector( + ReachableByResult( + task.taskStack, + PathElement(path.head.node, callSiteStack, isOutputArg = true) +: path.tail, + partial = true + ) + ) - def createPartialResultForOutputArgOrRet() = + /** Determine results for the current node + */ + val res = curNode match + // Case 1: we have reached a source => return result and continue traversing (expand into parents) + case x if sources.contains(x.asInstanceOf[NodeType]) => + if x.isInstanceOf[MethodParameterIn] then Vector( - ReachableByResult( - task.taskStack, - PathElement(path.head.node, callSiteStack, isOutputArg = true) +: path.tail, - partial = true - ) - ) - - /** Determine results for the current node - */ - val res = curNode match - // Case 1: we have reached a source => return result and continue traversing (expand into parents) - case x if sources.contains(x.asInstanceOf[NodeType]) => - if x.isInstanceOf[MethodParameterIn] then - Vector( - ReachableByResult(task.taskStack, path), - ReachableByResult(task.taskStack, path, partial = true) - ) ++ computeResultsForParents() - else - Vector(ReachableByResult(task.taskStack, path)) ++ computeResultsForParents() - // Case 2: we have reached a method parameter (that isn't a source) => return partial result and stop traversing - case _: MethodParameterIn => - Vector(ReachableByResult(task.taskStack, path, partial = true)) - // Case 3: we have reached a call to an internal method without semantic (return value) and - // this isn't the start node => return partial result and stop traversing - case call: Call - if isCallToInternalMethodWithoutSemantic(call) - && !isArgOrRetOfMethodWeCameFrom(call, path) => - createPartialResultForOutputArgOrRet() - - // Case 4: we have reached an argument to an internal method without semantic (output argument) and - // this isn't the start node nor is it the argument for the parameter we just expanded => return partial result and stop traversing - case arg: Expression - if path.size > 1 - && arg.inCall.toList.exists(c => isCallToInternalMethodWithoutSemantic(c)) - && !arg.inCall.headOption.exists(x => isArgOrRetOfMethodWeCameFrom(x, path)) => - createPartialResultForOutputArgOrRet() - - case _: MethodRef => createPartialResultForOutputArgOrRet() - - // All other cases: expand into parents - case _ => - computeResultsForParents() - val key = TaskFingerprint(curNode.asInstanceOf[CfgNode], task.callSiteStack, task.callDepth) - table.updateWith(key) { - case Some(existingValue) => Some(existingValue ++ res) - case None => Some(res) - } - res - end results - - private def isArgOrRetOfMethodWeCameFrom(call: Call, path: Vector[PathElement]): Boolean = - path match - case Vector(_, PathElement(x: MethodReturn, _, _, _, _), _*) => - methodsForCall(call).contains(x.method) - case Vector(_, PathElement(x: MethodParameterIn, _, _, _, _), _*) => - methodsForCall(call).contains(x.method) - case _ => false + ReachableByResult(task.taskStack, path), + ReachableByResult(task.taskStack, path, partial = true) + ) ++ computeResultsForParents() + else + Vector(ReachableByResult(task.taskStack, path)) ++ computeResultsForParents() + // Case 2: we have reached a method parameter (that isn't a source) => return partial result and stop traversing + case _: MethodParameterIn => + Vector(ReachableByResult(task.taskStack, path, partial = true)) + // Case 3: we have reached a call to an internal method without semantic (return value) and + // this isn't the start node => return partial result and stop traversing + case call: Call + if isCallToInternalMethodWithoutSemantic(call) + && !isArgOrRetOfMethodWeCameFrom(call, path) => + createPartialResultForOutputArgOrRet() + + // Case 4: we have reached an argument to an internal method without semantic (output argument) and + // this isn't the start node nor is it the argument for the parameter we just expanded => return partial result and stop traversing + case arg: Expression + if path.size > 1 + && arg.inCall.toList.exists(c => isCallToInternalMethodWithoutSemantic(c)) + && !arg.inCall.headOption.exists(x => isArgOrRetOfMethodWeCameFrom(x, path)) => + createPartialResultForOutputArgOrRet() + + case _: MethodRef => createPartialResultForOutputArgOrRet() + + // All other cases: expand into parents + case _ => + computeResultsForParents() + val key = TaskFingerprint(curNode.asInstanceOf[CfgNode], task.callSiteStack, task.callDepth) + table.updateWith(key) { + case Some(existingValue) => Some(existingValue ++ res) + case None => Some(res) + } + res + end results + + private def isArgOrRetOfMethodWeCameFrom(call: Call, path: Vector[PathElement]): Boolean = + path match + case Vector(_, PathElement(x: MethodReturn, _, _, _, _), _*) => + methodsForCall(call).contains(x.method) + case Vector(_, PathElement(x: MethodParameterIn, _, _, _, _), _*) => + methodsForCall(call).contains(x.method) + case _ => false end TaskSolver diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/package.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/package.scala index 646a96cd..4cd292df 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/package.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/queryengine/package.scala @@ -4,109 +4,109 @@ import io.shiftleft.codepropertygraph.generated.nodes.{AstNode, Call, CfgNode} package object queryengine: - /** The TaskFingerprint uniquely identifies a task. - */ - case class TaskFingerprint(sink: CfgNode, callSiteStack: List[Call], callDepth: Int) - - /** A (partial) result, informing about a path that exists from a source to another node in the - * graph. - * - * @param taskStack - * The list of tasks that was solved to arrive at this task - * - * @param path - * A path to the sink. - * - * @param partial - * indicate whether this result stands on its own or requires further analysis, e.g., by - * expanding output arguments backwards into method output parameters. - */ - case class ReachableByResult( - taskStack: List[TaskFingerprint], - path: Vector[PathElement], - partial: Boolean = false - ): + /** The TaskFingerprint uniquely identifies a task. + */ + case class TaskFingerprint(sink: CfgNode, callSiteStack: List[Call], callDepth: Int) - def fingerprint: TaskFingerprint = taskStack.last - def sink: CfgNode = fingerprint.sink - def callSiteStack: List[Call] = fingerprint.callSiteStack + /** A (partial) result, informing about a path that exists from a source to another node in the + * graph. + * + * @param taskStack + * The list of tasks that was solved to arrive at this task + * + * @param path + * A path to the sink. + * + * @param partial + * indicate whether this result stands on its own or requires further analysis, e.g., by + * expanding output arguments backwards into method output parameters. + */ + case class ReachableByResult( + taskStack: List[TaskFingerprint], + path: Vector[PathElement], + partial: Boolean = false + ): - def callDepth: Int = fingerprint.callDepth + def fingerprint: TaskFingerprint = taskStack.last + def sink: CfgNode = fingerprint.sink + def callSiteStack: List[Call] = fingerprint.callSiteStack - def startingPoint: CfgNode = path.head.node.asInstanceOf[CfgNode] + def callDepth: Int = fingerprint.callDepth - /** If the result begins in an output argument, return it. - */ - def outputArgument: Option[CfgNode] = - path.headOption.collect { - case elem: PathElement if elem.isOutputArg => - elem.node.asInstanceOf[CfgNode] - } - end ReachableByResult + def startingPoint: CfgNode = path.head.node.asInstanceOf[CfgNode] - /** We represent data flows as sequences of path elements, where each path element consists of a - * node, flags and the label of its outgoing edge. - * - * @param node - * The parent node. This is actually always a CfgNode during data flow computation, however, - * since the source may be an arbitrary AST node, we may add an AST node to the start of the - * flow right before returning flows to the user. - * - * @param callSiteStack - * The call stack when this path element was created. Since we may enter the same function - * via two different call sites, path elements should only be treated as the same if they are - * the same node and we've reached them via the same call sequence. - * - * @param visible - * whether this path element should be shown in the flow - * @param isOutputArg - * input and output arguments are the same node in the CPG, so, we need this additional flag - * to determine whether we are on an input or output argument. By default, we consider - * arguments to be input arguments, meaning that when tracking `x` at `f(x)`, we do not - * expand into `f` but rather upwards to producers of `x`. - * @param outEdgeLabel - * label of the outgoing DDG edge + /** If the result begins in an output argument, return it. */ - case class PathElement( - node: AstNode, - callSiteStack: List[Call] = List(), - visible: Boolean = true, - isOutputArg: Boolean = false, - outEdgeLabel: String = "" - ) + def outputArgument: Option[CfgNode] = + path.headOption.collect { + case elem: PathElement if elem.isOutputArg => + elem.node.asInstanceOf[CfgNode] + } + end ReachableByResult - /** @param taskStack - * The list of tasks that was solved to arrive at this task, including the current task, - * which is to be solved. The current task is the last element of the list. - * - * @param initialPath - * The path from the current sink downwards to previous sinks. - */ - case class ReachableByTask(taskStack: List[TaskFingerprint], initialPath: Vector[PathElement]): + /** We represent data flows as sequences of path elements, where each path element consists of a + * node, flags and the label of its outgoing edge. + * + * @param node + * The parent node. This is actually always a CfgNode during data flow computation, however, + * since the source may be an arbitrary AST node, we may add an AST node to the start of the + * flow right before returning flows to the user. + * + * @param callSiteStack + * The call stack when this path element was created. Since we may enter the same function via + * two different call sites, path elements should only be treated as the same if they are the + * same node and we've reached them via the same call sequence. + * + * @param visible + * whether this path element should be shown in the flow + * @param isOutputArg + * input and output arguments are the same node in the CPG, so, we need this additional flag to + * determine whether we are on an input or output argument. By default, we consider arguments + * to be input arguments, meaning that when tracking `x` at `f(x)`, we do not expand into `f` + * but rather upwards to producers of `x`. + * @param outEdgeLabel + * label of the outgoing DDG edge + */ + case class PathElement( + node: AstNode, + callSiteStack: List[Call] = List(), + visible: Boolean = true, + isOutputArg: Boolean = false, + outEdgeLabel: String = "" + ) - /** This tasks fingerprint: if two tasks have the same fingerprint, then the TaskSolver MUST - * return the same result for them. This is the basis of our caching scheme. - */ - def fingerprint: TaskFingerprint = taskStack.last + /** @param taskStack + * The list of tasks that was solved to arrive at this task, including the current task, which + * is to be solved. The current task is the last element of the list. + * + * @param initialPath + * The path from the current sink downwards to previous sinks. + */ + case class ReachableByTask(taskStack: List[TaskFingerprint], initialPath: Vector[PathElement]): - /** The sink at which we start the analysis (upwards) - */ - def sink: CfgNode = fingerprint.sink + /** This tasks fingerprint: if two tasks have the same fingerprint, then the TaskSolver MUST + * return the same result for them. This is the basis of our caching scheme. + */ + def fingerprint: TaskFingerprint = taskStack.last - /** The call sites we have expanded downwards during this analysis. We need to keep track of - * this so that we do not end up expanding one call site and then returning to a different - * call site, which would produce an unreachable path. - */ - def callSiteStack: List[Call] = fingerprint.callSiteStack + /** The sink at which we start the analysis (upwards) + */ + def sink: CfgNode = fingerprint.sink + + /** The call sites we have expanded downwards during this analysis. We need to keep track of + * this so that we do not end up expanding one call site and then returning to a different call + * site, which would produce an unreachable path. + */ + def callSiteStack: List[Call] = fingerprint.callSiteStack - /** The call depth at which this task was created. - */ - def callDepth: Int = fingerprint.callDepth - end ReachableByTask + /** The call depth at which this task was created. + */ + def callDepth: Int = fingerprint.callDepth + end ReachableByTask - case class TaskSummary( - tableEntries: Vector[(TaskFingerprint, TableEntry)], - followupTasks: Vector[ReachableByTask] - ) - case class TableEntry(path: Vector[PathElement]) + case class TaskSummary( + tableEntries: Vector[(TaskFingerprint, TableEntry)], + followupTasks: Vector[ReachableByTask] + ) + case class TableEntry(path: Vector[PathElement]) end queryengine diff --git a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/semanticsloader/Parser.scala b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/semanticsloader/Parser.scala index 4e5a26d5..820da78b 100644 --- a/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/semanticsloader/Parser.scala +++ b/dataflowengineoss/src/main/scala/io/appthreat/dataflowengineoss/semanticsloader/Parser.scala @@ -11,54 +11,53 @@ import scala.jdk.CollectionConverters.* object Semantics: - def fromList(elements: List[FlowSemantic]): Semantics = - new Semantics( - mutable.Map.newBuilder - .addAll(elements.map { e => - e.methodFullName -> e - }) - .result() - ) + def fromList(elements: List[FlowSemantic]): Semantics = + new Semantics( + mutable.Map.newBuilder + .addAll(elements.map { e => + e.methodFullName -> e + }) + .result() + ) - def empty: Semantics = fromList(List()) + def empty: Semantics = fromList(List()) class Semantics private (methodToSemantic: mutable.Map[String, FlowSemantic]): - /** The map below keeps a mapping between results of a regex and the regex string it matches. - * e.g. - * - * `path/to/file.py:.Foo.sink` -> `^path.*Foo\\.sink$` - */ - private val regexMatchedFullNames = mutable.HashMap.empty[String, String] - - /** Initialize all the method semantics that use regex with all their regex results before query - * time. - */ - def loadRegexSemantics(cpg: Cpg): Unit = - import io.shiftleft.semanticcpg.language.* - - methodToSemantic.filter(_._2.regex).foreach { case (regexString, _) => - cpg.method.fullName(regexString).fullName.foreach { methodMatch => - regexMatchedFullNames.put(methodMatch, regexString) - } + /** The map below keeps a mapping between results of a regex and the regex string it matches. e.g. + * + * `path/to/file.py:.Foo.sink` -> `^path.*Foo\\.sink$` + */ + private val regexMatchedFullNames = mutable.HashMap.empty[String, String] + + /** Initialize all the method semantics that use regex with all their regex results before query + * time. + */ + def loadRegexSemantics(cpg: Cpg): Unit = + import io.shiftleft.semanticcpg.language.* + + methodToSemantic.filter(_._2.regex).foreach { case (regexString, _) => + cpg.method.fullName(regexString).fullName.foreach { methodMatch => + regexMatchedFullNames.put(methodMatch, regexString) } - - def elements: List[FlowSemantic] = methodToSemantic.values.toList - - def forMethod(fullName: String): Option[FlowSemantic] = - regexMatchedFullNames.get(fullName) match - case Some(matchedFullName) => methodToSemantic.get(matchedFullName) - case None => methodToSemantic.get(fullName) - - def serialize: String = - elements - .sortBy(_.methodFullName) - .map { elem => - s"\"${elem.methodFullName}\" " + elem.mappings - .collect { case FlowMapping(x, y) => s"$x -> $y" } - .mkString(" ") - } - .mkString("\n") + } + + def elements: List[FlowSemantic] = methodToSemantic.values.toList + + def forMethod(fullName: String): Option[FlowSemantic] = + regexMatchedFullNames.get(fullName) match + case Some(matchedFullName) => methodToSemantic.get(matchedFullName) + case None => methodToSemantic.get(fullName) + + def serialize: String = + elements + .sortBy(_.methodFullName) + .map { elem => + s"\"${elem.methodFullName}\" " + elem.mappings + .collect { case FlowMapping(x, y) => s"$x -> $y" } + .mkString(" ") + } + .mkString("\n") end Semantics case class FlowSemantic( @@ -69,19 +68,19 @@ case class FlowSemantic( object FlowSemantic: - def from(methodFullName: String, mappings: List[?], regex: Boolean = false): FlowSemantic = - FlowSemantic( - methodFullName, - mappings.map { - case (src: Int, dst: Int) => FlowMapping(src, dst) - case (srcIdx: Int, src: String, dst: Int) => FlowMapping(srcIdx, src, dst) - case (src: Int, dstIdx: Int, dst: String) => FlowMapping(src, dstIdx, dst) - case (srcIdx: Int, src: String, dstIdx: Int, dst: String) => - FlowMapping(srcIdx, src, dstIdx, dst) - case x: FlowMapping => x - }, - regex - ) + def from(methodFullName: String, mappings: List[?], regex: Boolean = false): FlowSemantic = + FlowSemantic( + methodFullName, + mappings.map { + case (src: Int, dst: Int) => FlowMapping(src, dst) + case (srcIdx: Int, src: String, dst: Int) => FlowMapping(srcIdx, src, dst) + case (src: Int, dstIdx: Int, dst: String) => FlowMapping(src, dstIdx, dst) + case (srcIdx: Int, src: String, dstIdx: Int, dst: String) => + FlowMapping(srcIdx, src, dstIdx, dst) + case x: FlowMapping => x + }, + regex + ) abstract class FlowNode @@ -91,12 +90,12 @@ abstract class FlowNode */ trait ParamOrRetNode extends FlowNode: - /** Temporary backward compatible idx field. - * - * @return - * the argument index. - */ - def index: Int + /** Temporary backward compatible idx field. + * + * @return + * the argument index. + */ + def index: Int /** A parameter where the index of the argument matches the position of the parameter at the callee. * The name is used to match named arguments if used instead of positional arguments. @@ -109,7 +108,7 @@ trait ParamOrRetNode extends FlowNode: case class ParameterNode(index: Int, name: Option[String] = None) extends ParamOrRetNode object ParameterNode: - def apply(index: Int, name: String): ParameterNode = ParameterNode(index, Option(name)) + def apply(index: Int, name: String): ParameterNode = ParameterNode(index, Option(name)) /** Represents explicit mappings or special cases. */ @@ -125,16 +124,16 @@ sealed trait FlowPath case class FlowMapping(src: FlowNode, dst: FlowNode) extends FlowPath object FlowMapping: - def apply(from: Int, to: Int): FlowMapping = FlowMapping(ParameterNode(from), ParameterNode(to)) + def apply(from: Int, to: Int): FlowMapping = FlowMapping(ParameterNode(from), ParameterNode(to)) - def apply(fromIdx: Int, from: String, toIdx: Int, to: String): FlowMapping = - FlowMapping(ParameterNode(fromIdx, from), ParameterNode(toIdx, to)) + def apply(fromIdx: Int, from: String, toIdx: Int, to: String): FlowMapping = + FlowMapping(ParameterNode(fromIdx, from), ParameterNode(toIdx, to)) - def apply(fromIdx: Int, from: String, toIdx: Int): FlowMapping = - FlowMapping(ParameterNode(fromIdx, from), ParameterNode(toIdx)) + def apply(fromIdx: Int, from: String, toIdx: Int): FlowMapping = + FlowMapping(ParameterNode(fromIdx, from), ParameterNode(toIdx)) - def apply(from: Int, toIdx: Int, to: String): FlowMapping = - FlowMapping(ParameterNode(from), ParameterNode(toIdx, to)) + def apply(from: Int, toIdx: Int, to: String): FlowMapping = + FlowMapping(ParameterNode(from), ParameterNode(toIdx, to)) /** Represents an instance where parameters are not sanitized, may affect the return value, and do * not cross-taint. e.g. foo(1, 2) = 1 -> 1, 2 -> 2, 1 -> -1, 2 -> -1 @@ -146,55 +145,55 @@ object PassThroughMapping extends FlowPath class Parser(): - def parse(input: String): List[FlowSemantic] = - val charStream = CharStreams.fromString(input) - parseCharStream(charStream) + def parse(input: String): List[FlowSemantic] = + val charStream = CharStreams.fromString(input) + parseCharStream(charStream) - def parseFile(fileName: String): List[FlowSemantic] = - val charStream = CharStreams.fromFileName(fileName) - parseCharStream(charStream) + def parseFile(fileName: String): List[FlowSemantic] = + val charStream = CharStreams.fromFileName(fileName) + parseCharStream(charStream) - private def parseCharStream(charStream: CharStream): List[FlowSemantic] = - val lexer = new SemanticsLexer(charStream) - val tokenStream = new CommonTokenStream(lexer) - val parser = new SemanticsParser(tokenStream) - val treeWalker = new ParseTreeWalker() + private def parseCharStream(charStream: CharStream): List[FlowSemantic] = + val lexer = new SemanticsLexer(charStream) + val tokenStream = new CommonTokenStream(lexer) + val parser = new SemanticsParser(tokenStream) + val treeWalker = new ParseTreeWalker() - val tree = parser.taintSemantics() - val listener = new Listener() - treeWalker.walk(listener, tree) - listener.result.toList + val tree = parser.taintSemantics() + val listener = new Listener() + treeWalker.walk(listener, tree) + listener.result.toList - implicit class AntlrFlowExtensions(val ctx: MappingContext): + implicit class AntlrFlowExtensions(val ctx: MappingContext): - def isPassThrough: Boolean = Option(ctx.PASSTHROUGH()).isDefined + def isPassThrough: Boolean = Option(ctx.PASSTHROUGH()).isDefined - def srcIdx: Int = ctx.src().argIdx().NUMBER().getText.toInt + def srcIdx: Int = ctx.src().argIdx().NUMBER().getText.toInt - def srcArgName: Option[String] = Option(ctx.src().argName()).map(_.name().getText) + def srcArgName: Option[String] = Option(ctx.src().argName()).map(_.name().getText) - def dstIdx: Int = ctx.dst().argIdx().NUMBER().getText.toInt + def dstIdx: Int = ctx.dst().argIdx().NUMBER().getText.toInt - def dstArgName: Option[String] = Option(ctx.dst().argName()).map(_.name().getText) + def dstArgName: Option[String] = Option(ctx.dst().argName()).map(_.name().getText) - private class Listener extends SemanticsBaseListener: + private class Listener extends SemanticsBaseListener: - val result: mutable.ListBuffer[FlowSemantic] = mutable.ListBuffer[FlowSemantic]() + val result: mutable.ListBuffer[FlowSemantic] = mutable.ListBuffer[FlowSemantic]() - override def enterTaintSemantics(ctx: SemanticsParser.TaintSemanticsContext): Unit = - ctx.singleSemantic().asScala.foreach { semantic => - val methodName = semantic.methodName().name().getText - val mappings = semantic.mapping().asScala.toList.map(ctxToParamMapping) - result.addOne(FlowSemantic(methodName, mappings)) - } + override def enterTaintSemantics(ctx: SemanticsParser.TaintSemanticsContext): Unit = + ctx.singleSemantic().asScala.foreach { semantic => + val methodName = semantic.methodName().name().getText + val mappings = semantic.mapping().asScala.toList.map(ctxToParamMapping) + result.addOne(FlowSemantic(methodName, mappings)) + } - private def ctxToParamMapping(ctx: MappingContext): FlowPath = - if ctx.isPassThrough then - PassThroughMapping - else - val src = ParameterNode(ctx.srcIdx, ctx.srcArgName) - val dst = ParameterNode(ctx.dstIdx, ctx.dstArgName) + private def ctxToParamMapping(ctx: MappingContext): FlowPath = + if ctx.isPassThrough then + PassThroughMapping + else + val src = ParameterNode(ctx.srcIdx, ctx.srcArgName) + val dst = ParameterNode(ctx.dstIdx, ctx.dstArgName) - FlowMapping(src, dst) - end Listener + FlowMapping(src, dst) + end Listener end Parser diff --git a/macros/src/main/scala/io/appthreat/console/Query.scala b/macros/src/main/scala/io/appthreat/console/Query.scala index 2e52d570..5c9369aa 100644 --- a/macros/src/main/scala/io/appthreat/console/Query.scala +++ b/macros/src/main/scala/io/appthreat/console/Query.scala @@ -25,29 +25,29 @@ case class Query( ) object Query: - def make( - name: String, - author: String, - title: String, - description: String, - score: Double, - traversalWithStrRep: TraversalWithStrRep, - tags: List[String] = List(), - codeExamples: CodeExamples = CodeExamples(List(), List()), - multiFileCodeExamples: MultiFileCodeExamples = MultiFileCodeExamples(List(), List()) - ): Query = - Query( - name = name, - author = author, - title = title, - description = description, - score = score, - traversal = traversalWithStrRep.traversal, - traversalAsString = traversalWithStrRep.strRep, - tags = tags, - codeExamples = codeExamples, - multiFileCodeExamples = multiFileCodeExamples - ) + def make( + name: String, + author: String, + title: String, + description: String, + score: Double, + traversalWithStrRep: TraversalWithStrRep, + tags: List[String] = List(), + codeExamples: CodeExamples = CodeExamples(List(), List()), + multiFileCodeExamples: MultiFileCodeExamples = MultiFileCodeExamples(List(), List()) + ): Query = + Query( + name = name, + author = author, + title = title, + description = description, + score = score, + traversal = traversalWithStrRep.traversal, + traversalAsString = traversalWithStrRep.strRep, + tags = tags, + codeExamples = codeExamples, + multiFileCodeExamples = multiFileCodeExamples + ) end Query case class TraversalWithStrRep(traversal: Cpg => Iterator[? <: StoredNode], strRep: String = "") diff --git a/macros/src/main/scala/io/appthreat/console/QueryDatabase.scala b/macros/src/main/scala/io/appthreat/console/QueryDatabase.scala index 9fbb420c..3653702c 100644 --- a/macros/src/main/scala/io/appthreat/console/QueryDatabase.scala +++ b/macros/src/main/scala/io/appthreat/console/QueryDatabase.scala @@ -14,59 +14,59 @@ class QueryDatabase( namespace: String = "io.appthreat.scanners" ): - /** Determine all bundles on the class path - */ - def allBundles: List[Class[? <: QueryBundle]] = - new Reflections( - new ConfigurationBuilder().setUrls( - ClasspathHelper.forPackage( - namespace, - ClasspathHelper.contextClassLoader(), - ClasspathHelper.staticClassLoader() - ) + /** Determine all bundles on the class path + */ + def allBundles: List[Class[? <: QueryBundle]] = + new Reflections( + new ConfigurationBuilder().setUrls( + ClasspathHelper.forPackage( + namespace, + ClasspathHelper.contextClassLoader(), + ClasspathHelper.staticClassLoader() ) - ).getSubTypesOf(classOf[QueryBundle]).asScala.toList + ) + ).getSubTypesOf(classOf[QueryBundle]).asScala.toList - /** Determine queries across all bundles - */ - def allQueries: List[Query] = - allBundles.flatMap { bundle => - queriesInBundle(bundle) - } + /** Determine queries across all bundles + */ + def allQueries: List[Query] = + allBundles.flatMap { bundle => + queriesInBundle(bundle) + } - /** Return all queries inside `bundle`. - */ - def queriesInBundle[T <: QueryBundle](bundle: Class[T]): List[Query] = - val instance = bundle.getField("MODULE$").get(null) - queryCreatorsInBundle(bundle).map { case (method, args) => - val query = method.invoke(instance, args*).asInstanceOf[Query] - val bundleNamespace = bundle.getPackageName - // the namespace currently looks like `io.appthreat.scanners.c.CopyLoops` - val namespaceParts = bundleNamespace.split('.') - val language = - if bundleNamespace.startsWith("io.appthreat.chen.scanners") then - namespaceParts(4) - else if namespaceParts.length > 3 then - namespaceParts(3) - else - "" - query.copy(language = language) - } + /** Return all queries inside `bundle`. + */ + def queriesInBundle[T <: QueryBundle](bundle: Class[T]): List[Query] = + val instance = bundle.getField("MODULE$").get(null) + queryCreatorsInBundle(bundle).map { case (method, args) => + val query = method.invoke(instance, args*).asInstanceOf[Query] + val bundleNamespace = bundle.getPackageName + // the namespace currently looks like `io.appthreat.scanners.c.CopyLoops` + val namespaceParts = bundleNamespace.split('.') + val language = + if bundleNamespace.startsWith("io.appthreat.chen.scanners") then + namespaceParts(4) + else if namespaceParts.length > 3 then + namespaceParts(3) + else + "" + query.copy(language = language) + } - /** Obtain all (method, args) pairs from bundle, making it possible to override default args - * before creating the query. - */ - def queryCreatorsInBundle[T <: QueryBundle](bundle: Class[T]): List[(Method, List[Any])] = - val methods = bundle.getMethods.filter(_.getAnnotations.exists(_.isInstanceOf[q])).toList - methods.map { method => - val args = defaultArgs(method, bundle) - (method, args) - } + /** Obtain all (method, args) pairs from bundle, making it possible to override default args + * before creating the query. + */ + def queryCreatorsInBundle[T <: QueryBundle](bundle: Class[T]): List[(Method, List[Any])] = + val methods = bundle.getMethods.filter(_.getAnnotations.exists(_.isInstanceOf[q])).toList + methods.map { method => + val args = defaultArgs(method, bundle) + (method, args) + } - private def defaultArgs[T <: QueryBundle](method: Method, bundle: Class[T]): List[Any] = - method.getParameters.zipWithIndex.map { case (parameter, index) => - defaultArgumentProvider.defaultArgument(method, bundle, parameter, index) - }.toList + private def defaultArgs[T <: QueryBundle](method: Method, bundle: Class[T]): List[Any] = + method.getParameters.zipWithIndex.map { case (parameter, index) => + defaultArgumentProvider.defaultArgument(method, bundle, parameter, index) + }.toList end QueryDatabase /** Joern and Ocular require different implicits to be present, and when we encounter these @@ -77,22 +77,22 @@ end QueryDatabase */ class DefaultArgumentProvider: - def typeSpecificDefaultArg(@unused argTypeFullName: String): Option[Any] = - None + def typeSpecificDefaultArg(@unused argTypeFullName: String): Option[Any] = + None - final def defaultArgument(method: Method, bundle: Class[?], parameter: Parameter, i: Int): Any = - val instance = bundle.getField("MODULE$").get(null) - val defaultArgOption = typeSpecificDefaultArg(parameter.getType.getTypeName) - defaultArgOption.getOrElse { - val defaultMethodName = s"${method.getName}$$default$$${i + 1}" - try - val defaultMethod = bundle.getDeclaredMethod(defaultMethodName) - val defaultValue = defaultMethod.invoke(instance) - defaultValue - catch - case e: NoSuchMethodException => - throw new RuntimeException( - s"No default value found for parameter `${parameter.toString}` of query creator method `$method` " - ) - } + final def defaultArgument(method: Method, bundle: Class[?], parameter: Parameter, i: Int): Any = + val instance = bundle.getField("MODULE$").get(null) + val defaultArgOption = typeSpecificDefaultArg(parameter.getType.getTypeName) + defaultArgOption.getOrElse { + val defaultMethodName = s"${method.getName}$$default$$${i + 1}" + try + val defaultMethod = bundle.getDeclaredMethod(defaultMethodName) + val defaultValue = defaultMethod.invoke(instance) + defaultValue + catch + case e: NoSuchMethodException => + throw new RuntimeException( + s"No default value found for parameter `${parameter.toString}` of query creator method `$method` " + ) + } end DefaultArgumentProvider diff --git a/macros/src/main/scala/io/appthreat/macros/QueryMacros.scala b/macros/src/main/scala/io/appthreat/macros/QueryMacros.scala index 716d23e5..96d485f4 100644 --- a/macros/src/main/scala/io/appthreat/macros/QueryMacros.scala +++ b/macros/src/main/scala/io/appthreat/macros/QueryMacros.scala @@ -8,13 +8,13 @@ import scala.quoted.{Expr, Quotes} object QueryMacros: - inline def withStrRep(inline traversal: Cpg => Iterator[? <: StoredNode]): TraversalWithStrRep = - ${ withStrRepImpl('{ traversal }) } + inline def withStrRep(inline traversal: Cpg => Iterator[? <: StoredNode]): TraversalWithStrRep = + ${ withStrRepImpl('{ traversal }) } - private def withStrRepImpl( - travExpr: Expr[Cpg => Iterator[? <: StoredNode]] - )(using quotes: Quotes): Expr[TraversalWithStrRep] = - import quotes.reflect.* - val pos = travExpr.asTerm.pos - val code = Position(pos.sourceFile, pos.start, pos.end).sourceCode.getOrElse("N/A") - '{ TraversalWithStrRep(${ travExpr }, ${ Expr(code) }) } + private def withStrRepImpl( + travExpr: Expr[Cpg => Iterator[? <: StoredNode]] + )(using quotes: Quotes): Expr[TraversalWithStrRep] = + import quotes.reflect.* + val pos = travExpr.asTerm.pos + val code = Position(pos.sourceFile, pos.start, pos.end).sourceCode.getOrElse("N/A") + '{ TraversalWithStrRep(${ travExpr }, ${ Expr(code) }) } diff --git a/macros/src/test/scala/io/appthreat/console/QueryDatabaseTests.scala b/macros/src/test/scala/io/appthreat/console/QueryDatabaseTests.scala index b715722b..aa9fcfcf 100644 --- a/macros/src/test/scala/io/appthreat/console/QueryDatabaseTests.scala +++ b/macros/src/test/scala/io/appthreat/console/QueryDatabaseTests.scala @@ -1,7 +1,7 @@ package io.appthreat.console import io.shiftleft.codepropertygraph.Cpg -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import org.scalatest.matchers.should import org.scalatest.wordspec.AnyWordSpec diff --git a/macros/src/test/scala/io/appthreat/macros/QueryMacroTests.scala b/macros/src/test/scala/io/appthreat/macros/QueryMacroTests.scala index 134dd107..4bf94666 100644 --- a/macros/src/test/scala/io/appthreat/macros/QueryMacroTests.scala +++ b/macros/src/test/scala/io/appthreat/macros/QueryMacroTests.scala @@ -4,8 +4,8 @@ import org.scalatest.matchers.should.Matchers import org.scalatest.wordspec.AnyWordSpec import QueryMacros.withStrRep -import io.appthreat.console._ -import io.shiftleft.semanticcpg.language._ +import io.appthreat.console.* +import io.shiftleft.semanticcpg.language.* class QueryMacroTests extends AnyWordSpec with Matchers { "Query macros" should { diff --git a/meta.yaml b/meta.yaml index 977e8f25..28d4ae61 100644 --- a/meta.yaml +++ b/meta.yaml @@ -1,4 +1,4 @@ -{% set version = "2.1.2" %} +{% set version = "2.1.3" %} package: name: chen diff --git a/platform/build.sbt b/platform/build.sbt index bd1115cd..6e5c09af 100644 --- a/platform/build.sbt +++ b/platform/build.sbt @@ -53,7 +53,7 @@ Universal / mappings += cpgVersionFile.value -> "schema-extender/cpg-version" lazy val generateScaladocs = taskKey[File]("generate scaladocs from combined project sources") generateScaladocs := { - import better.files._ + import better.files.* import java.io.{File => JFile, PrintWriter} import sbt.internal.inc.AnalyzingCompiler import sbt.internal.util.Attributed.data diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/C2Cpg.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/C2Cpg.scala index 143e003b..d4833b71 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/C2Cpg.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/C2Cpg.scala @@ -18,18 +18,18 @@ import scala.util.Try class C2Cpg extends X2CpgFrontend[Config]: - private val report: Report = new Report() + private val report: Report = new Report() - def createCpg(config: Config): Try[Cpg] = - withNewEmptyCpg(config.outputPath, config) { (cpg, config) => - new MetaDataPass(cpg, Languages.NEWC, config.inputPath).createAndApply() - new AstCreationPass(cpg, config, report).createAndApply() - new ConfigFileCreationPass(cpg).createAndApply() - TypeNodePass.withRegisteredTypes(CGlobal.typesSeen(), cpg).createAndApply() - new TypeDeclNodePass(cpg)(config.schemaValidation).createAndApply() - report.print() - } + def createCpg(config: Config): Try[Cpg] = + withNewEmptyCpg(config.outputPath, config) { (cpg, config) => + new MetaDataPass(cpg, Languages.NEWC, config.inputPath).createAndApply() + new AstCreationPass(cpg, config, report).createAndApply() + new ConfigFileCreationPass(cpg).createAndApply() + TypeNodePass.withRegisteredTypes(CGlobal.typesSeen(), cpg).createAndApply() + new TypeDeclNodePass(cpg)(config.schemaValidation).createAndApply() + report.print() + } - def printIfDefsOnly(config: Config): Unit = - val stmts = new PreprocessorPass(config).run().mkString(",") - println(stmts) + def printIfDefsOnly(config: Config): Unit = + val stmts = new PreprocessorPass(config).run().mkString(",") + println(stmts) diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/Main.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/Main.scala index 4a234820..825cc2a7 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/Main.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/Main.scala @@ -22,113 +22,113 @@ final case class Config( includeImageLocations: Boolean = false, useProjectIndex: Boolean = true ) extends X2CpgConfig[Config]: - def withIncludeFiles(includeFiles: Set[String]): Config = - this.copy(includeFiles = includeFiles).withInheritedFields(this) - def withIncludePaths(includePaths: Set[String]): Config = - this.copy(includePaths = includePaths).withInheritedFields(this) + def withIncludeFiles(includeFiles: Set[String]): Config = + this.copy(includeFiles = includeFiles).withInheritedFields(this) + def withIncludePaths(includePaths: Set[String]): Config = + this.copy(includePaths = includePaths).withInheritedFields(this) - def withMacroFiles(macroFiles: Set[String]): Config = - this.copy(macroFiles = macroFiles).withInheritedFields(this) - def withDefines(defines: Set[String]): Config = - this.copy(defines = defines).withInheritedFields(this) + def withMacroFiles(macroFiles: Set[String]): Config = + this.copy(macroFiles = macroFiles).withInheritedFields(this) + def withDefines(defines: Set[String]): Config = + this.copy(defines = defines).withInheritedFields(this) - def withIncludeComments(value: Boolean): Config = - this.copy(includeComments = value).withInheritedFields(this) + def withIncludeComments(value: Boolean): Config = + this.copy(includeComments = value).withInheritedFields(this) - def withLogProblems(value: Boolean): Config = - this.copy(logProblems = value).withInheritedFields(this) + def withLogProblems(value: Boolean): Config = + this.copy(logProblems = value).withInheritedFields(this) - def withLogPreprocessor(value: Boolean): Config = - this.copy(logPreprocessor = value).withInheritedFields(this) + def withLogPreprocessor(value: Boolean): Config = + this.copy(logPreprocessor = value).withInheritedFields(this) - def withPrintIfDefsOnly(value: Boolean): Config = - this.copy(printIfDefsOnly = value).withInheritedFields(this) + def withPrintIfDefsOnly(value: Boolean): Config = + this.copy(printIfDefsOnly = value).withInheritedFields(this) - def withIncludePathsAutoDiscovery(value: Boolean): Config = - this.copy(includePathsAutoDiscovery = value).withInheritedFields(this) + def withIncludePathsAutoDiscovery(value: Boolean): Config = + this.copy(includePathsAutoDiscovery = value).withInheritedFields(this) - def withFunctionBodies(value: Boolean): Config = - this.copy(includeFunctionBodies = value).withInheritedFields(this) + def withFunctionBodies(value: Boolean): Config = + this.copy(includeFunctionBodies = value).withInheritedFields(this) - def withImageLocations(value: Boolean): Config = - this.copy(includeImageLocations = value).withInheritedFields(this) + def withImageLocations(value: Boolean): Config = + this.copy(includeImageLocations = value).withInheritedFields(this) - def withProjectIndexes(value: Boolean): Config = - this.copy(useProjectIndex = value).withInheritedFields(this) + def withProjectIndexes(value: Boolean): Config = + this.copy(useProjectIndex = value).withInheritedFields(this) end Config private object Frontend: - implicit val defaultConfig: Config = Config() - - val cmdLineParser: OParser[Unit, Config] = - val builder = OParser.builder[Config] - import builder.* - OParser.sequence( - programName(classOf[C2Cpg].getSimpleName), - opt[Unit]("include-comments") - .text(s"includes all comments into the CPG") - .action((_, c) => c.withIncludeComments(true)), - opt[Unit]("log-problems") - .text(s"enables logging of all parse problems while generating the CPG") - .action((_, c) => c.withLogProblems(true)), - opt[Unit]("log-preprocessor") - .text(s"enables logging of all preprocessor statements while generating the CPG") - .action((_, c) => c.withLogPreprocessor(true)), - opt[Unit]("print-ifdef-only") - .text( - s"prints a comma-separated list of all preprocessor ifdef and if statements; does not create a CPG" - ) - .action((_, c) => c.withPrintIfDefsOnly(true)), - opt[String]("include") - .unbounded() - .text("header include paths") - .action((incl, c) => c.withIncludePaths(c.includePaths + incl)), - opt[String]("include-files") - .unbounded() - .text("header include files") - .action((inclf, c) => c.withIncludeFiles(c.includeFiles + inclf)), - opt[String]("macro-files") - .unbounded() - .text("macro files") - .action((macrof, c) => c.withMacroFiles(c.macroFiles + macrof)), - opt[Unit]("no-include-auto-discovery") - .text("disables auto discovery of system header include paths") - .hidden(), - opt[Unit]("with-include-auto-discovery") - .text("enables auto discovery of system header include paths") - .action((_, c) => c.withIncludePathsAutoDiscovery(true)), - opt[Unit]("with-function-bodies") - .text("instructs the parser to parse function and method bodies.") - .action((_, c) => c.withFunctionBodies(true)), - opt[Unit]("with-image-locations") - .text( - "allows the parser to create image-locations. An image location explains how a name made it into the translation unit. Eg: via macro expansion or preprocessor." - ) - .action((_, c) => c.withImageLocations(true)), - opt[Unit]("with-project-index") - .text( - "performance optimization, allows the parser to use an existing eclipse project(s) index(es)." - ) - .action((_, c) => c.withProjectIndexes(true)), - opt[String]("define") - .unbounded() - .text("define a name") - .action((d, c) => c.withDefines(c.defines + d)) - ) - end cmdLineParser + implicit val defaultConfig: Config = Config() + + val cmdLineParser: OParser[Unit, Config] = + val builder = OParser.builder[Config] + import builder.* + OParser.sequence( + programName(classOf[C2Cpg].getSimpleName), + opt[Unit]("include-comments") + .text(s"includes all comments into the CPG") + .action((_, c) => c.withIncludeComments(true)), + opt[Unit]("log-problems") + .text(s"enables logging of all parse problems while generating the CPG") + .action((_, c) => c.withLogProblems(true)), + opt[Unit]("log-preprocessor") + .text(s"enables logging of all preprocessor statements while generating the CPG") + .action((_, c) => c.withLogPreprocessor(true)), + opt[Unit]("print-ifdef-only") + .text( + s"prints a comma-separated list of all preprocessor ifdef and if statements; does not create a CPG" + ) + .action((_, c) => c.withPrintIfDefsOnly(true)), + opt[String]("include") + .unbounded() + .text("header include paths") + .action((incl, c) => c.withIncludePaths(c.includePaths + incl)), + opt[String]("include-files") + .unbounded() + .text("header include files") + .action((inclf, c) => c.withIncludeFiles(c.includeFiles + inclf)), + opt[String]("macro-files") + .unbounded() + .text("macro files") + .action((macrof, c) => c.withMacroFiles(c.macroFiles + macrof)), + opt[Unit]("no-include-auto-discovery") + .text("disables auto discovery of system header include paths") + .hidden(), + opt[Unit]("with-include-auto-discovery") + .text("enables auto discovery of system header include paths") + .action((_, c) => c.withIncludePathsAutoDiscovery(true)), + opt[Unit]("with-function-bodies") + .text("instructs the parser to parse function and method bodies.") + .action((_, c) => c.withFunctionBodies(true)), + opt[Unit]("with-image-locations") + .text( + "allows the parser to create image-locations. An image location explains how a name made it into the translation unit. Eg: via macro expansion or preprocessor." + ) + .action((_, c) => c.withImageLocations(true)), + opt[Unit]("with-project-index") + .text( + "performance optimization, allows the parser to use an existing eclipse project(s) index(es)." + ) + .action((_, c) => c.withProjectIndexes(true)), + opt[String]("define") + .unbounded() + .text("define a name") + .action((d, c) => c.withDefines(c.defines + d)) + ) + end cmdLineParser end Frontend object Main extends X2CpgMain(cmdLineParser, new C2Cpg()): - private val logger = LoggerFactory.getLogger(classOf[C2Cpg]) - - def run(config: Config, c2cpg: C2Cpg): Unit = - if config.printIfDefsOnly then - try - c2cpg.printIfDefsOnly(config) - catch - case NonFatal(ex) => - logger.debug("Failed to print preprocessor statements.", ex) - throw ex - else - c2cpg.run(config) + private val logger = LoggerFactory.getLogger(classOf[C2Cpg]) + + def run(config: Config, c2cpg: C2Cpg): Unit = + if config.printIfDefsOnly then + try + c2cpg.printIfDefsOnly(config) + catch + case NonFatal(ex) => + logger.debug("Failed to print preprocessor statements.", ex) + throw ex + else + c2cpg.run(config) diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstCreator.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstCreator.scala index 1617bb76..0ad0704d 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstCreator.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstCreator.scala @@ -38,82 +38,81 @@ class AstCreator( with MacroHandler with X2CpgAstNodeBuilder[IASTNode, AstCreator]: - protected val logger: Logger = LoggerFactory.getLogger(classOf[AstCreator]) + protected val logger: Logger = LoggerFactory.getLogger(classOf[AstCreator]) - protected val scope: Scope[String, (NewNode, String), NewNode] = new Scope() + protected val scope: Scope[String, (NewNode, String), NewNode] = new Scope() - protected val usingDeclarationMappings: mutable.Map[String, String] = mutable.HashMap.empty + protected val usingDeclarationMappings: mutable.Map[String, String] = mutable.HashMap.empty - // TypeDecls with their bindings (with their refs) for lambdas and methods are not put in the AST - // where the respective nodes are defined. Instead we put them under the parent TYPE_DECL in which they are defined. - // To achieve this we need this extra stack. - protected val methodAstParentStack: Stack[NewNode] = new Stack() + // TypeDecls with their bindings (with their refs) for lambdas and methods are not put in the AST + // where the respective nodes are defined. Instead we put them under the parent TYPE_DECL in which they are defined. + // To achieve this we need this extra stack. + protected val methodAstParentStack: Stack[NewNode] = new Stack() - def createAst(): DiffGraphBuilder = - val ast = astForTranslationUnit(cdtAst) - Ast.storeInDiffGraph(ast, diffGraph) - diffGraph + def createAst(): DiffGraphBuilder = + val ast = astForTranslationUnit(cdtAst) + Ast.storeInDiffGraph(ast, diffGraph) + diffGraph - private def astForTranslationUnit(iASTTranslationUnit: IASTTranslationUnit): Ast = - val namespaceBlock = globalNamespaceBlock() - methodAstParentStack.push(namespaceBlock) - val translationUnitAst = - astInFakeMethod( - namespaceBlock.fullName, - fileName(iASTTranslationUnit), - iASTTranslationUnit - ) - val depsAndImportsAsts = astsForDependenciesAndImports(iASTTranslationUnit) - val commentsAsts = astsForComments(iASTTranslationUnit) - val childrenAsts = depsAndImportsAsts ++ Seq(translationUnitAst) ++ commentsAsts - setArgumentIndices(childrenAsts) - Ast(namespaceBlock).withChildren(childrenAsts) + private def astForTranslationUnit(iASTTranslationUnit: IASTTranslationUnit): Ast = + val namespaceBlock = globalNamespaceBlock() + methodAstParentStack.push(namespaceBlock) + val translationUnitAst = + astInFakeMethod( + namespaceBlock.fullName, + fileName(iASTTranslationUnit), + iASTTranslationUnit + ) + val depsAndImportsAsts = astsForDependenciesAndImports(iASTTranslationUnit) + val commentsAsts = astsForComments(iASTTranslationUnit) + val childrenAsts = depsAndImportsAsts ++ Seq(translationUnitAst) ++ commentsAsts + setArgumentIndices(childrenAsts) + Ast(namespaceBlock).withChildren(childrenAsts) - /** Creates an AST of all declarations found in the translation unit - wrapped in a fake method. - */ - private def astInFakeMethod( - fullName: String, - path: String, - iASTTranslationUnit: IASTTranslationUnit - ): Ast = - val allDecls = iASTTranslationUnit.getDeclarations.toList.filterNot(isIncludedNode) - val name = NamespaceTraversal.globalNamespaceName + /** Creates an AST of all declarations found in the translation unit - wrapped in a fake method. + */ + private def astInFakeMethod( + fullName: String, + path: String, + iASTTranslationUnit: IASTTranslationUnit + ): Ast = + val allDecls = iASTTranslationUnit.getDeclarations.toList.filterNot(isIncludedNode) + val name = NamespaceTraversal.globalNamespaceName - val fakeGlobalTypeDecl = - typeDeclNode( - iASTTranslationUnit, - name, - fullName, - filename, - name, - NodeTypes.NAMESPACE_BLOCK, - fullName - ) - methodAstParentStack.push(fakeGlobalTypeDecl) + val fakeGlobalTypeDecl = + typeDeclNode( + iASTTranslationUnit, + name, + fullName, + filename, + name, + NodeTypes.NAMESPACE_BLOCK, + fullName + ) + methodAstParentStack.push(fakeGlobalTypeDecl) - val fakeGlobalMethod = - methodNode( - iASTTranslationUnit, - name, - name, - fullName, - None, - path, - Option(NodeTypes.TYPE_DECL), - Option(fullName) - ) - methodAstParentStack.push(fakeGlobalMethod) - scope.pushNewScope(fakeGlobalMethod) + val fakeGlobalMethod = + methodNode( + iASTTranslationUnit, + name, + name, + fullName, + None, + path, + Option(NodeTypes.TYPE_DECL), + Option(fullName) + ) + methodAstParentStack.push(fakeGlobalMethod) + scope.pushNewScope(fakeGlobalMethod) - val blockNode_ = - blockNode(iASTTranslationUnit, Defines.empty, registerType(Defines.anyTypeName)) + val blockNode_ = blockNode(iASTTranslationUnit) - val declsAsts = allDecls.flatMap(astsForDeclaration) - setArgumentIndices(declsAsts) + val declsAsts = allDecls.flatMap(astsForDeclaration) + setArgumentIndices(declsAsts) - val methodReturn = newMethodReturnNode(iASTTranslationUnit, Defines.anyTypeName) - Ast(fakeGlobalTypeDecl).withChild( - methodAst(fakeGlobalMethod, Seq.empty, blockAst(blockNode_, declsAsts), methodReturn) - ) - end astInFakeMethod + val methodReturn = newMethodReturnNode(iASTTranslationUnit, Defines.anyTypeName) + Ast(fakeGlobalTypeDecl).withChild( + methodAst(fakeGlobalMethod, Seq.empty, blockAst(blockNode_, declsAsts), methodReturn) + ) + end astInFakeMethod end AstCreator diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstCreatorHelper.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstCreatorHelper.scala index eb0480f6..8d01351f 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstCreatorHelper.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstCreatorHelper.scala @@ -2,8 +2,9 @@ package io.appthreat.c2cpg.astcreation import io.appthreat.c2cpg.datastructures.CGlobal import io.appthreat.x2cpg.utils.NodeBuilders.newDependencyNode +import io.appthreat.x2cpg.Defines as X2CpgDefines import io.appthreat.x2cpg.{Ast, SourceFiles, ValidationMode} -import io.shiftleft.codepropertygraph.generated.nodes.{ExpressionNew, NewNode} +import io.shiftleft.codepropertygraph.generated.nodes.{ExpressionNew, NewCall, NewNode} import io.shiftleft.codepropertygraph.generated.{DispatchTypes, EdgeTypes, Operators} import io.shiftleft.utils.IOUtils import org.apache.commons.lang.StringUtils @@ -15,7 +16,7 @@ import org.eclipse.cdt.core.dom.ast.c.{ } import org.eclipse.cdt.core.dom.ast.cpp.* import org.eclipse.cdt.core.dom.ast.gnu.c.ICASTKnRFunctionDeclarator -import org.eclipse.cdt.internal.core.dom.parser.c.CASTArrayRangeDesignator +import org.eclipse.cdt.internal.core.dom.parser.c.{CASTArrayRangeDesignator, CASTFunctionDeclarator} import org.eclipse.cdt.internal.core.dom.parser.cpp.* import org.eclipse.cdt.internal.core.dom.parser.cpp.semantics.{EvalBinding, EvalMemberAccess} import org.eclipse.cdt.internal.core.model.ASTStringUtil @@ -27,558 +28,614 @@ import scala.util.Try object AstCreatorHelper: - implicit class OptionSafeAst(val ast: Ast) extends AnyVal: - def withArgEdge(src: NewNode, dst: Option[NewNode]): Ast = dst match - case Some(value) => ast.withArgEdge(src, value) - case None => ast + implicit class OptionSafeAst(val ast: Ast) extends AnyVal: + def withArgEdge(src: NewNode, dst: Option[NewNode]): Ast = dst match + case Some(value) => ast.withArgEdge(src, value) + case None => ast trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode): - this: AstCreator => - - import AstCreatorHelper.* - - private val IncludeKeyword = "include" - // Sadly, there is no predefined List / Enum of this within Eclipse CDT: - private val reservedTypeKeywords: List[String] = - List( - "const", - "static", - "volatile", - "restrict", - "extern", - "typedef", - "inline", - "constexpr", - "auto", - "virtual", - "enum", - "struct", - "interface", - "class", - "naked", - "export", - "module", - "import" - ) - private var usedVariablePostfix: Int = 0 - - protected def uniqueName(target: String, name: String, fullName: String): (String, String) = - if name.isEmpty && (fullName.isEmpty || fullName.endsWith(".")) then - val name = s"anonymous_${target}_$usedVariablePostfix" - val resultingFullName = s"$fullName$name" - usedVariablePostfix = usedVariablePostfix + 1 - (name, resultingFullName) - else - (name, fullName) - - protected def code(node: IASTNode): String = shortenCode(nodeSignature(node)) - - protected def line(node: IASTNode): Option[Integer] = - nullSafeFileLocation(node).map(_.getStartingLineNumber) - - protected def lineEnd(node: IASTNode): Option[Integer] = - nullSafeFileLocationLast(node).map(_.getEndingLineNumber) - - protected def column(node: IASTNode): Option[Integer] = - val loc = nullSafeFileLocation(node) - loc.map { x => - offsetToColumn(node, x.getNodeOffset) - } - - private def offsetToColumn(node: IASTNode, offset: Int): Int = - val table = fileOffsetTable(node) - val index = java.util.Arrays.binarySearch(table, offset) - val tableIndex = if index < 0 then -(index + 1) else index + 1 - val lineStartOffset = if tableIndex == 0 then - 0 + this: AstCreator => + + import AstCreatorHelper.* + + private val IncludeKeyword = "include" + // Sadly, there is no predefined List / Enum of this within Eclipse CDT: + private val reservedTypeKeywords: List[String] = + List( + "const", + "static", + "volatile", + "restrict", + "extern", + "typedef", + "inline", + "constexpr", + "auto", + "virtual", + "enum", + "struct", + "interface", + "class", + "naked", + "export", + "module", + "import" + ) + private var usedVariablePostfix: Int = 0 + + def createCallAst( + callNode: NewCall, + arguments: Seq[Ast] = List(), + base: Option[Ast] = None, + receiver: Option[Ast] = None + ): Ast = + + setArgumentIndices(arguments) + + val baseRoot = base.flatMap(_.root).toList + val bse = base.getOrElse(Ast()) + baseRoot match + case List(x: ExpressionNew) => + x.argumentIndex = 0 + case _ => + + var ast = + Ast(callNode) + .withChild(bse) + + if receiver.isDefined && receiver != base then + receiver.get.root.get.asInstanceOf[ExpressionNew].argumentIndex = -1 + ast = ast.withChild(receiver.get) + + ast = ast + .withChildren(arguments) + .withArgEdges(callNode, baseRoot) + .withArgEdges(callNode, arguments.flatMap(_.root)) + + if receiver.isDefined then + ast = ast.withReceiverEdge(callNode, receiver.get.root.get) + + ast + end createCallAst + + protected def uniqueName(target: String, name: String, fullName: String): (String, String) = + if name.isEmpty && (fullName.isEmpty || fullName.endsWith(".")) then + val name = s"anonymous_${target}_$usedVariablePostfix" + val resultingFullName = s"$fullName$name" + usedVariablePostfix = usedVariablePostfix + 1 + (name, resultingFullName) + else + (name, fullName) + + protected def code(node: IASTNode): String = shortenCode(nodeSignature(node)) + + protected def line(node: IASTNode): Option[Integer] = + nullSafeFileLocation(node).map(_.getStartingLineNumber) + + protected def lineEnd(node: IASTNode): Option[Integer] = + nullSafeFileLocationLast(node).map(_.getEndingLineNumber) + + private def nullSafeFileLocationLast(node: IASTNode): Option[IASTFileLocation] = + Option(cdtAst.flattenLocationsToFile(node.getNodeLocations.lastOption.toArray)).map( + _.asFileLocation() + ) + + protected def column(node: IASTNode): Option[Integer] = + val loc = nullSafeFileLocation(node) + loc.map { x => + offsetToColumn(node, x.getNodeOffset) + } + + protected def columnEnd(node: IASTNode): Option[Integer] = + val loc = nullSafeFileLocation(node) + loc.map { x => + offsetToColumn(node, x.getNodeOffset + x.getNodeLength - 1) + } + + private def offsetToColumn(node: IASTNode, offset: Int): Int = + val table = fileOffsetTable(node) + val index = java.util.Arrays.binarySearch(table, offset) + val tableIndex = if index < 0 then -(index + 1) else index + 1 + val lineStartOffset = if tableIndex == 0 then + 0 + else + table(tableIndex - 1) + val column = offset - lineStartOffset + 1 + column + + private def fileOffsetTable(node: IASTNode): Array[Int] = + val path = SourceFiles.toAbsolutePath(fileName(node), config.inputPath) + file2OffsetTable.computeIfAbsent(path, _ => genFileOffsetTable(Paths.get(path))) + + private def genFileOffsetTable(absolutePath: Path): Array[Int] = + val asCharArray = IOUtils.readLinesInFile(absolutePath).mkString("\n").toCharArray + val offsets = mutable.ArrayBuffer.empty[Int] + + for i <- Range(0, asCharArray.length) do + if asCharArray(i) == '\n' then + offsets.append(i + 1) + offsets.toArray + + protected def fileName(node: IASTNode): String = + val path = nullSafeFileLocation(node).map(_.getFileName).getOrElse(filename) + SourceFiles.toRelativePath(path, config.inputPath) + + private def nullSafeFileLocation(node: IASTNode): Option[IASTFileLocation] = + Option(cdtAst.flattenLocationsToFile(node.getNodeLocations)).map(_.asFileLocation()) + + protected def registerType(typeName: String): String = + val fixedTypeName = fixQualifiedName(StringUtils.normalizeSpace(typeName)) + CGlobal.usedTypes.putIfAbsent(fixedTypeName, true) + fixedTypeName + + protected def fixQualifiedName(name: String): String = + name.stripPrefix(Defines.qualifiedNameSeparator).replace( + Defines.qualifiedNameSeparator, + "." + ) + + protected def cleanType(rawType: String, stripKeywords: Boolean = true): String = + val tpe = + if stripKeywords then + reservedTypeKeywords.foldLeft(rawType) { (cur, repl) => + if cur.contains(s"$repl ") then + dereferenceTypeFullName(cur.replace(s"$repl ", "")) + else + cur + } else - table(tableIndex - 1) - val column = offset - lineStartOffset + 1 - column - - private def fileOffsetTable(node: IASTNode): Array[Int] = - val path = SourceFiles.toAbsolutePath(fileName(node), config.inputPath) - file2OffsetTable.computeIfAbsent(path, _ => genFileOffsetTable(Paths.get(path))) - - private def genFileOffsetTable(absolutePath: Path): Array[Int] = - val asCharArray = IOUtils.readLinesInFile(absolutePath).mkString("\n").toCharArray - val offsets = mutable.ArrayBuffer.empty[Int] - - for i <- Range(0, asCharArray.length) do - if asCharArray(i) == '\n' then - offsets.append(i + 1) - offsets.toArray - - protected def fileName(node: IASTNode): String = - val path = nullSafeFileLocation(node).map(_.getFileName).getOrElse(filename) - SourceFiles.toRelativePath(path, config.inputPath) - - private def nullSafeFileLocation(node: IASTNode): Option[IASTFileLocation] = - Option(cdtAst.flattenLocationsToFile(node.getNodeLocations)).map(_.asFileLocation()) - - protected def columnEnd(node: IASTNode): Option[Integer] = - val loc = nullSafeFileLocation(node) - loc.map { x => - offsetToColumn(node, x.getNodeOffset + x.getNodeLength - 1) - } - - protected def registerType(typeName: String): String = - val fixedTypeName = fixQualifiedName(StringUtils.normalizeSpace(typeName)) - CGlobal.usedTypes.putIfAbsent(fixedTypeName, true) - fixedTypeName - - protected def fixQualifiedName(name: String): String = - name.stripPrefix(Defines.qualifiedNameSeparator).replace( - Defines.qualifiedNameSeparator, - "." - ) - - protected def cleanType(rawType: String, stripKeywords: Boolean = true): String = - val tpe = - if stripKeywords then - reservedTypeKeywords.foldLeft(rawType) { (cur, repl) => - if cur.contains(s"$repl ") then - dereferenceTypeFullName(cur.replace(s"$repl ", "")) - else - cur - } - else - rawType - StringUtils.normalizeSpace(tpe) match - case "" => Defines.anyTypeName - case t if t.contains("org.eclipse.cdt.internal.core.dom.parser.ProblemType") => - Defines.anyTypeName - case t if t.contains(" ->") && t.contains("}::") => - fixQualifiedName(t.substring(t.indexOf("}::") + 3, t.indexOf(" ->"))) - case t if t.contains(" ->") => - fixQualifiedName(t.substring(0, t.indexOf(" ->"))) - case t if t.contains("( ") => - fixQualifiedName(t.substring(0, t.indexOf("( "))) - case t if t.contains("?") => Defines.anyTypeName - case t if t.contains("#") => Defines.anyTypeName - case t if t.contains("{") && t.contains("}") => - val anonType = - s"${uniqueName("type", "", "")._1}${t.substring(0, t.indexOf("{"))}${t.substring(t.indexOf("}") + 1)}" - anonType.replace(" ", "") - case t if t.startsWith("[") && t.endsWith("]") => Defines.anyTypeName - case t if t.contains(Defines.qualifiedNameSeparator) => fixQualifiedName(t) - case t if t.startsWith("unsigned ") => "unsigned " + t.substring(9).replace(" ", "") - case t if t.contains("[") && t.contains("]") => t.replace(" ", "") - case t if t.contains("*") => t.replace(" ", "") - case someType => someType - end match - end cleanType - - @nowarn - protected def typeFor(node: IASTNode, stripKeywords: Boolean = true): String = - import org.eclipse.cdt.core.dom.ast.ASTSignatureUtil.getNodeSignature - node match - case f: CPPASTFieldReference => - safeGetEvaluation(f.getFieldOwner) match - case Some(evaluation: EvalBinding) => - cleanType(evaluation.getType.toString, stripKeywords) - case _ => cleanType( - ASTTypeUtil.getType(f.getFieldOwner.getExpressionType), - stripKeywords - ) - case f: IASTFieldReference => - cleanType(ASTTypeUtil.getType(f.getFieldOwner.getExpressionType), stripKeywords) - case a: IASTArrayDeclarator if ASTTypeUtil.getNodeType(a).startsWith("? ") => - val tpe = getNodeSignature(a).replace("[]", "").strip() - val arr = ASTTypeUtil.getNodeType(a).replace("? ", "") - s"$tpe$arr" - case a: IASTArrayDeclarator if ASTTypeUtil.getNodeType(a).contains("} ") => - val tpe = getNodeSignature(a).replace("[]", "").strip() - val nodeType = ASTTypeUtil.getNodeType(node) - val arr = nodeType.substring(nodeType.indexOf("["), nodeType.indexOf("]") + 1) - s"$tpe$arr" - case a: IASTArrayDeclarator if ASTTypeUtil.getNodeType(a).contains(" [") => - cleanType(ASTTypeUtil.getNodeType(node)) - case s: CPPASTIdExpression => - safeGetEvaluation(s) match - case Some(evaluation: EvalMemberAccess) => - cleanType(evaluation.getOwnerType.toString, stripKeywords) - case Some(evalBinding: EvalBinding) => - evalBinding.getBinding match - case m: CPPMethod => cleanType(fullName(m.getDefinition)) - case _ => cleanType(ASTTypeUtil.getNodeType(s), stripKeywords) - case _ => cleanType(ASTTypeUtil.getNodeType(s), stripKeywords) - case _: IASTIdExpression | _: IASTName | _: IASTDeclarator => - cleanType(ASTTypeUtil.getNodeType(node), stripKeywords) - case s: IASTNamedTypeSpecifier => - cleanType(ASTStringUtil.getReturnTypeString(s, null), stripKeywords) - case s: IASTCompositeTypeSpecifier => - cleanType(ASTStringUtil.getReturnTypeString(s, null), stripKeywords) - case s: IASTEnumerationSpecifier => - cleanType(ASTStringUtil.getReturnTypeString(s, null), stripKeywords) - case s: IASTElaboratedTypeSpecifier => - cleanType(ASTStringUtil.getReturnTypeString(s, null), stripKeywords) - case l: IASTLiteralExpression => - cleanType(ASTTypeUtil.getType(l.getExpressionType)) - case e: IASTExpression => - cleanType(ASTTypeUtil.getNodeType(e), stripKeywords) - case c: ICPPASTConstructorInitializer - if c.getParent.isInstanceOf[ICPPASTConstructorChainInitializer] => - cleanType( - fullName(c.getParent.asInstanceOf[ - ICPPASTConstructorChainInitializer - ].getMemberInitializerId), + rawType + StringUtils.normalizeSpace(tpe) match + case "" => Defines.anyTypeName + case t if t.contains("org.eclipse.cdt.internal.core.dom.parser.ProblemType") => + Defines.anyTypeName + case t if t.contains(" ->") && t.contains("}::") => + fixQualifiedName(t.substring(t.indexOf("}::") + 3, t.indexOf(" ->"))) + case t if t.contains(" ->") => + fixQualifiedName(t.substring(0, t.indexOf(" ->"))) + case t if t.contains("( ") => + fixQualifiedName(t.substring(0, t.indexOf("( "))) + case t if t.contains("?") => Defines.anyTypeName + case t if t.contains("#") => Defines.anyTypeName + case t if t.contains("{") && t.contains("}") => + val anonType = + s"${uniqueName("type", "", "")._1}${t + .substring(0, t.indexOf("{"))}${t.substring(t.indexOf("}") + 1)}" + anonType.replace(" ", "") + case t if t.startsWith("[") && t.endsWith("]") => Defines.anyTypeName + case t if t.contains(Defines.qualifiedNameSeparator) => fixQualifiedName(t) + case t if t.startsWith("unsigned ") => "unsigned " + t.substring(9).replace(" ", "") + case t if t.contains("[") && t.contains("]") => t.replace(" ", "") + case t if t.contains("*") => t.replace(" ", "") + case someType => someType + end match + end cleanType + + @nowarn + protected def typeFor(node: IASTNode, stripKeywords: Boolean = true): String = + import org.eclipse.cdt.core.dom.ast.ASTSignatureUtil.getNodeSignature + node match + case f: CPPASTFieldReference => + safeGetEvaluation(f.getFieldOwner) match + case Some(evaluation: EvalBinding) => + cleanType(evaluation.getType.toString, stripKeywords) + case _ => cleanType( + ASTTypeUtil.getType(f.getFieldOwner.getExpressionType), stripKeywords ) - case _ => - cleanType(getNodeSignature(node), stripKeywords) - end match - end typeFor - - protected def notHandledYet(node: IASTNode): Ast = - if !node.isInstanceOf[IASTProblem] && !node.isInstanceOf[IASTProblemHolder] then - val text = notHandledText(node) - logger.debug(text) - Ast(unknownNode(node, nodeSignature(node))) - - protected def nullSafeCode(node: IASTNode): String = - Option(node).map(nodeSignature).getOrElse("") - - protected def nullSafeAst(node: IASTExpression, argIndex: Int): Ast = - val r = nullSafeAst(node) - r.root match - case Some(x: ExpressionNew) => - x.argumentIndex = argIndex - case _ => - r - - protected def nullSafeAst(node: IASTExpression): Ast = - Option(node).map(astForNode).getOrElse(Ast()) - - protected def nullSafeAst(node: IASTStatement, argIndex: Int = -1): Seq[Ast] = - Option(node).map(astsForStatement(_, argIndex)).getOrElse(Seq.empty) - - protected def dereferenceTypeFullName(fullName: String): String = - fullName.replace("*", "") - - protected def isQualifiedName(name: String): Boolean = - name.startsWith(Defines.qualifiedNameSeparator) - - protected def lastNameOfQualifiedName(name: String): String = - val cleanedName = if name.contains("<") && name.contains(">") then - name.substring(0, name.indexOf("<")) - else - name - cleanedName.split(Defines.qualifiedNameSeparator).lastOption.getOrElse(cleanedName) - - protected def fullName(node: IASTNode): String = - val filename = fileName(node) - val lineNo: Integer = line(node).getOrElse(-1) - val lineNoEnd: Integer = lineEnd(node).getOrElse(-1) - val qualifiedName: String = node match - case d: CPPASTIdExpression if d.getEvaluation.isInstanceOf[EvalBinding] => - val evaluation = d.getEvaluation.asInstanceOf[EvalBinding] - evaluation.getBinding match - case f: CPPFunction if f.getDeclarations != null => - usingDeclarationMappings.getOrElse( - fixQualifiedName(ASTStringUtil.getSimpleName(d.getName)), - f.getDeclarations.headOption.map(n => - ASTStringUtil.getSimpleName(n.getName) - ).getOrElse(f.getName) - ) - case f: CPPFunction if f.getDefinition != null => - usingDeclarationMappings.getOrElse( - fixQualifiedName(ASTStringUtil.getSimpleName(d.getName)), - ASTStringUtil.getSimpleName(f.getDefinition.getName) - ) - case other => other.getName - case alias: ICPPASTNamespaceAlias => alias.getMappingName.toString - case namespace: ICPPASTNamespaceDefinition - if ASTStringUtil.getSimpleName(namespace.getName).nonEmpty => - s"${fullName(namespace.getParent)}.${ASTStringUtil.getSimpleName(namespace.getName)}" - case namespace: ICPPASTNamespaceDefinition - if ASTStringUtil.getSimpleName(namespace.getName).isEmpty => - s"${fullName(namespace.getParent)}.${uniqueName("namespace", "", "")._1}" - case compType: IASTCompositeTypeSpecifier - if ASTStringUtil.getSimpleName(compType.getName).nonEmpty => - s"${fullName(compType.getParent)}.${ASTStringUtil.getSimpleName(compType.getName)}" - case compType: IASTCompositeTypeSpecifier - if ASTStringUtil.getSimpleName(compType.getName).isEmpty => - val name = compType.getParent match - case decl: IASTSimpleDeclaration => - decl.getDeclarators.headOption - .map(n => ASTStringUtil.getSimpleName(n.getName)) - .getOrElse(uniqueName("composite_type", "", "")._1) - case _ => uniqueName("composite_type", "", "")._1 - s"${fullName(compType.getParent)}.$name" - case enumSpecifier: IASTEnumerationSpecifier => - s"${fullName(enumSpecifier.getParent)}.${ASTStringUtil.getSimpleName(enumSpecifier.getName)}" - case f: ICPPASTLambdaExpression => - s"${fullName(f.getParent)}." - case f: IASTFunctionDeclarator - if ASTStringUtil.getSimpleName( - f.getName - ).isEmpty && f.getNestedDeclarator != null => - val parentFullName = fullName(f.getParent) - val sn = shortName(f.getNestedDeclarator) - val fnWithParent = - if parentFullName.nonEmpty then s"${parentFullName}.${sn}" else sn - s"$filename:$lineNo:$lineNoEnd:${fnWithParent}" - case f: IASTFunctionDeclarator => - val parentFullName = fullName(f.getParent) - val sn = ASTStringUtil.getSimpleName(f.getName) - val fnWithParent = - if parentFullName.nonEmpty then s"${parentFullName}.${sn}" else sn - s"$filename:$lineNo:$lineNoEnd:${fnWithParent}" - case f: IASTFunctionDefinition if f.getDeclarator != null => - s"${fullName(f.getParent)}.${ASTStringUtil.getQualifiedName(f.getDeclarator.getName)}" - case f: IASTFunctionDefinition => - s"${fullName(f.getParent)}.${shortName(f)}" - case e: IASTElaboratedTypeSpecifier => - s"${fullName(e.getParent)}.${ASTStringUtil.getSimpleName(e.getName)}" - case d: IASTIdExpression => ASTStringUtil.getSimpleName(d.getName) - case _: IASTTranslationUnit => "" - case u: IASTUnaryExpression => nodeSignature(u.getOperand) - case x: ICPPASTQualifiedName => ASTStringUtil.getQualifiedName(x) - case other if other.getParent != null => fullName(other.getParent) - case other if other != null => notHandledYet(other); "" - case null => "" - fixQualifiedName(qualifiedName).stripPrefix(".") - end fullName - - protected def shortName(node: IASTNode): String = - val name = node match - case d: IASTDeclarator - if ASTStringUtil.getSimpleName( - d.getName - ).isEmpty && d.getNestedDeclarator != null => - shortName(d.getNestedDeclarator) - case d: IASTDeclarator => ASTStringUtil.getSimpleName(d.getName) - case f: ICPPASTFunctionDefinition - if ASTStringUtil - .getSimpleName(f.getDeclarator.getName) - .isEmpty && f.getDeclarator.getNestedDeclarator != null => - shortName(f.getDeclarator.getNestedDeclarator) - case f: ICPPASTFunctionDefinition => - lastNameOfQualifiedName(ASTStringUtil.getSimpleName(f.getDeclarator.getName)) - case f: IASTFunctionDefinition - if ASTStringUtil - .getSimpleName(f.getDeclarator.getName) - .isEmpty && f.getDeclarator.getNestedDeclarator != null => - shortName(f.getDeclarator.getNestedDeclarator) - case f: IASTFunctionDefinition => ASTStringUtil.getSimpleName(f.getDeclarator.getName) - case d: CPPASTIdExpression if d.getEvaluation.isInstanceOf[EvalBinding] => - val evaluation = d.getEvaluation.asInstanceOf[EvalBinding] - evaluation.getBinding match - case f: CPPFunction if f.getDeclarations != null => - f.getDeclarations.headOption.map(n => - ASTStringUtil.getSimpleName(n.getName) - ).getOrElse(f.getName) - case f: CPPFunction if f.getDefinition != null => - ASTStringUtil.getSimpleName(f.getDefinition.getName) - case other => - other.getName - case d: IASTIdExpression => - lastNameOfQualifiedName(ASTStringUtil.getSimpleName(d.getName)) - case u: IASTUnaryExpression => shortName(u.getOperand) - case c: IASTFunctionCallExpression => shortName(c.getFunctionNameExpression) - case s: IASTSimpleDeclSpecifier => s.getRawSignature - case e: IASTEnumerationSpecifier => ASTStringUtil.getSimpleName(e.getName) - case c: IASTCompositeTypeSpecifier => ASTStringUtil.getSimpleName(c.getName) - case e: IASTElaboratedTypeSpecifier => ASTStringUtil.getSimpleName(e.getName) - case s: IASTNamedTypeSpecifier => ASTStringUtil.getSimpleName(s.getName) - case other => notHandledYet(other); "" - name - end shortName - - protected def astsForDependenciesAndImports(iASTTranslationUnit: IASTTranslationUnit) - : Seq[Ast] = - val allIncludes = iASTTranslationUnit.getIncludeDirectives.toList.filterNot(isIncludedNode) - allIncludes.map { include => - val name = include.getName.toString - val _dependencyNode = newDependencyNode(name, name, IncludeKeyword) - val importNode = newImportNode(nodeSignature(include), name, name, include) - diffGraph.addNode(_dependencyNode) - diffGraph.addEdge(importNode, _dependencyNode, EdgeTypes.IMPORTS) - Ast(importNode) - } - - protected def isIncludedNode(node: IASTNode): Boolean = fileName(node) != filename - - protected def astsForComments(iASTTranslationUnit: IASTTranslationUnit): Seq[Ast] = - if config.includeComments then - iASTTranslationUnit.getComments.toList.filterNot(isIncludedNode).map(comment => - astForComment(comment) - ) - else - Seq.empty - - protected def astForNode(node: IASTNode): Ast = - if config.includeFunctionBodies then astForNodeFull(node) else astForNodePartial(node) - - protected def astForNodeFull(node: IASTNode): Ast = - node match - case expr: IASTExpression => astForExpression(expr) - case name: IASTName => astForIdentifier(name) - case decl: IASTDeclSpecifier => astForIdentifier(decl) - case l: IASTInitializerList => astForInitializerList(l) - case c: ICPPASTConstructorInitializer => astForCPPASTConstructorInitializer(c) - case d: ICASTDesignatedInitializer => astForCASTDesignatedInitializer(d) - case d: ICPPASTDesignatedInitializer => astForCPPASTDesignatedInitializer(d) - case d: CASTArrayRangeDesignator => astForCASTArrayRangeDesignator(d) - case d: CPPASTArrayRangeDesignator => astForCPPASTArrayRangeDesignator(d) - case d: ICASTArrayDesignator => nullSafeAst(d.getSubscriptExpression) - case d: ICPPASTArrayDesignator => nullSafeAst(d.getSubscriptExpression) - case d: ICPPASTFieldDesignator => astForNode(d.getName) - case d: ICASTFieldDesignator => astForNode(d.getName) - case decl: ICPPASTDecltypeSpecifier => astForDecltypeSpecifier(decl) - case arrMod: IASTArrayModifier => astForArrayModifier(arrMod) - case _ => notHandledYet(node) - - protected def astForNodePartial(node: IASTNode): Ast = - node match - case expr: IASTExpression => astForExpression(expr) - case name: IASTName => astForIdentifier(name) - case decl: IASTDeclSpecifier => astForIdentifier(decl) - case decl: ICPPASTDecltypeSpecifier => astForDecltypeSpecifier(decl) - case _ => notHandledYet(node) - - protected def typeForDeclSpecifier( - spec: IASTNode, - stripKeywords: Boolean = true, - index: Int = 0 - ): String = - val tpe = spec match - case s: IASTSimpleDeclSpecifier if s.getParent.isInstanceOf[IASTParameterDeclaration] => - val parentDecl = s.getParent.asInstanceOf[IASTParameterDeclaration].getDeclarator - pointersAsString(s, parentDecl, stripKeywords) - case s: IASTSimpleDeclSpecifier if s.getParent.isInstanceOf[IASTFunctionDefinition] => - val parentDecl = s.getParent.asInstanceOf[IASTFunctionDefinition].getDeclarator - ASTStringUtil.getReturnTypeString(s, parentDecl) - case s: IASTSimpleDeclaration if s.getParent.isInstanceOf[ICASTKnRFunctionDeclarator] => - val decl = s.getDeclarators.toList(index) - pointersAsString(s.getDeclSpecifier, decl, stripKeywords) - case s: IASTSimpleDeclSpecifier if s.getParent.isInstanceOf[IASTSimpleDeclaration] => - val parentDecl = - s.getParent.asInstanceOf[IASTSimpleDeclaration].getDeclarators.toList(index) - pointersAsString(s, parentDecl, stripKeywords) - case s: IASTSimpleDeclSpecifier => - ASTStringUtil.getReturnTypeString(s, null) - case s: IASTNamedTypeSpecifier if s.getParent.isInstanceOf[IASTParameterDeclaration] => - val parentDecl = s.getParent.asInstanceOf[IASTParameterDeclaration].getDeclarator - pointersAsString(s, parentDecl, stripKeywords) - case s: IASTNamedTypeSpecifier if s.getParent.isInstanceOf[IASTSimpleDeclaration] => - val parentDecl = - s.getParent.asInstanceOf[IASTSimpleDeclaration].getDeclarators.toList(index) - pointersAsString(s, parentDecl, stripKeywords) - case s: IASTNamedTypeSpecifier => ASTStringUtil.getSimpleName(s.getName) - case s: IASTCompositeTypeSpecifier if s.getParent.isInstanceOf[IASTSimpleDeclaration] => - val parentDecl = - s.getParent.asInstanceOf[IASTSimpleDeclaration].getDeclarators.toList(index) - pointersAsString(s, parentDecl, stripKeywords) - case s: IASTCompositeTypeSpecifier => ASTStringUtil.getSimpleName(s.getName) - case s: IASTEnumerationSpecifier if s.getParent.isInstanceOf[IASTSimpleDeclaration] => - val parentDecl = - s.getParent.asInstanceOf[IASTSimpleDeclaration].getDeclarators.toList(index) - pointersAsString(s, parentDecl, stripKeywords) - case s: IASTEnumerationSpecifier => ASTStringUtil.getSimpleName(s.getName) - case s: IASTElaboratedTypeSpecifier - if s.getParent.isInstanceOf[IASTParameterDeclaration] => - val parentDecl = s.getParent.asInstanceOf[IASTParameterDeclaration].getDeclarator - pointersAsString(s, parentDecl, stripKeywords) - case s: IASTElaboratedTypeSpecifier - if s.getParent.isInstanceOf[IASTSimpleDeclaration] => - val parentDecl = - s.getParent.asInstanceOf[IASTSimpleDeclaration].getDeclarators.toList(index) - pointersAsString(s, parentDecl, stripKeywords) - case s: IASTElaboratedTypeSpecifier => ASTStringUtil.getSignatureString(s, null) - // TODO: handle other types of IASTDeclSpecifier - case _ => Defines.anyTypeName - if tpe.isEmpty then Defines.anyTypeName else tpe - end typeForDeclSpecifier - - private def nullSafeFileLocationLast(node: IASTNode): Option[IASTFileLocation] = - Option(cdtAst.flattenLocationsToFile(node.getNodeLocations.lastOption.toArray)).map( - _.asFileLocation() + case f: IASTFieldReference => + cleanType(ASTTypeUtil.getType(f.getFieldOwner.getExpressionType), stripKeywords) + case a: IASTArrayDeclarator if ASTTypeUtil.getNodeType(a).startsWith("? ") => + val tpe = getNodeSignature(a).replace("[]", "").strip() + val arr = ASTTypeUtil.getNodeType(a).replace("? ", "") + s"$tpe$arr" + case a: IASTArrayDeclarator if ASTTypeUtil.getNodeType(a).contains("} ") => + val tpe = getNodeSignature(a).replace("[]", "").strip() + val nodeType = ASTTypeUtil.getNodeType(node) + val arr = nodeType.substring(nodeType.indexOf("["), nodeType.indexOf("]") + 1) + s"$tpe$arr" + case a: IASTArrayDeclarator if ASTTypeUtil.getNodeType(a).contains(" [") => + cleanType(ASTTypeUtil.getNodeType(node)) + case s: CPPASTIdExpression => + safeGetEvaluation(s) match + case Some(evaluation: EvalMemberAccess) => + cleanType(evaluation.getOwnerType.toString, stripKeywords) + case Some(evalBinding: EvalBinding) => + evalBinding.getBinding match + case m: CPPMethod => cleanType(fullName(m.getDefinition)) + case _ => cleanType(ASTTypeUtil.getNodeType(s), stripKeywords) + case _ => cleanType(ASTTypeUtil.getNodeType(s), stripKeywords) + case _: IASTIdExpression | _: IASTName | _: IASTDeclarator => + cleanType(ASTTypeUtil.getNodeType(node), stripKeywords) + case s: IASTNamedTypeSpecifier => + cleanType(ASTStringUtil.getReturnTypeString(s, null), stripKeywords) + case s: IASTCompositeTypeSpecifier => + cleanType(ASTStringUtil.getReturnTypeString(s, null), stripKeywords) + case s: IASTEnumerationSpecifier => + cleanType(ASTStringUtil.getReturnTypeString(s, null), stripKeywords) + case s: IASTElaboratedTypeSpecifier => + cleanType(ASTStringUtil.getReturnTypeString(s, null), stripKeywords) + case l: IASTLiteralExpression => + cleanType(ASTTypeUtil.getType(l.getExpressionType)) + case e: IASTExpression => + cleanType(ASTTypeUtil.getNodeType(e), stripKeywords) + case c: ICPPASTConstructorInitializer + if c.getParent.isInstanceOf[ICPPASTConstructorChainInitializer] => + cleanType( + fullName(c.getParent.asInstanceOf[ + ICPPASTConstructorChainInitializer + ].getMemberInitializerId), + stripKeywords + ) + case _ => + cleanType(getNodeSignature(node), stripKeywords) + end match + end typeFor + + protected def notHandledYet(node: IASTNode): Ast = + if !node.isInstanceOf[IASTProblem] && !node.isInstanceOf[IASTProblemHolder] then + val text = notHandledText(node) + logger.debug(text) + Ast(unknownNode(node, nodeSignature(node))) + + protected def nullSafeCode(node: IASTNode): String = + Option(node).map(nodeSignature).getOrElse("") + + protected def nullSafeAst(node: IASTExpression, argIndex: Int): Ast = + val r = nullSafeAst(node) + r.root match + case Some(x: ExpressionNew) => + x.argumentIndex = argIndex + case _ => + r + + protected def nullSafeAst(node: IASTExpression): Ast = + Option(node).map(astForNode).getOrElse(Ast()) + + protected def nullSafeAst(node: IASTStatement, argIndex: Int = -1): Seq[Ast] = + Option(node).map(astsForStatement(_, argIndex)).getOrElse(Seq.empty) + + protected def dereferenceTypeFullName(fullName: String): String = + fullName.replace("*", "") + + protected def isQualifiedName(name: String): Boolean = + name.startsWith(Defines.qualifiedNameSeparator) + + protected def lastNameOfQualifiedName(name: String): String = + val cleanedName = if name.contains("<") && name.contains(">") then + name.substring(0, name.indexOf("<")) + else + name + cleanedName.split(Defines.qualifiedNameSeparator).lastOption.getOrElse(cleanedName) + + protected def functionTypeToSignature(typ: IFunctionType): String = + val returnType = ASTTypeUtil.getType(typ.getReturnType) + val parameterTypes = typ.getParameterTypes.map(ASTTypeUtil.getType) + s"$returnType(${parameterTypes.mkString(",")})" + + protected def fullName(node: IASTNode): String = + node match + case declarator: CPPASTFunctionDeclarator => + declarator.getName.resolveBinding() match + case function: ICPPFunction => + val fullNameNoSig = function.getQualifiedName.mkString(".") + val fn = + if function.isExternC then + function.getName + else + s"$fullNameNoSig:${functionTypeToSignature(function.getType)}" + return fn + case field: ICPPField => + case _: IProblemBinding => + val fullNameNoSig = ASTStringUtil.getQualifiedName(declarator.getName) + val fixedFullName = fixQualifiedName(fullNameNoSig).stripPrefix(".") + if fixedFullName.isEmpty then + return "" + else + return s"$fixedFullName" + case declarator: CASTFunctionDeclarator => + val fn = declarator.getName.toString + return fn + case definition: ICPPASTFunctionDefinition => + return fullName(definition.getDeclarator) + case x => + end match + + val qualifiedName: String = node match + case d: CPPASTIdExpression => + safeGetEvaluation(d) match + case Some(evalBinding: EvalBinding) => + evalBinding.getBinding match + case f: CPPFunction if f.getDeclarations != null => + f.getDeclarations.headOption.map(n => s"${fullName(n)}").getOrElse(f.getName) + case f: CPPFunction if f.getDefinition != null => + s"${fullName(f.getDefinition)}" + case other => + other.getName + case _ => ASTStringUtil.getSimpleName(d.getName) + + case alias: ICPPASTNamespaceAlias => alias.getMappingName.toString + case namespace: ICPPASTNamespaceDefinition + if ASTStringUtil.getSimpleName(namespace.getName).nonEmpty => + s"${fullName(namespace.getParent)}.${ASTStringUtil.getSimpleName(namespace.getName)}" + case namespace: ICPPASTNamespaceDefinition + if ASTStringUtil.getSimpleName(namespace.getName).isEmpty => + s"${fullName(namespace.getParent)}.${uniqueName("namespace", "", "")._1}" + case compType: IASTCompositeTypeSpecifier + if ASTStringUtil.getSimpleName(compType.getName).nonEmpty => + s"${fullName(compType.getParent)}.${ASTStringUtil.getSimpleName(compType.getName)}" + case compType: IASTCompositeTypeSpecifier + if ASTStringUtil.getSimpleName(compType.getName).isEmpty => + val name = compType.getParent match + case decl: IASTSimpleDeclaration => + decl.getDeclarators.headOption + .map(n => ASTStringUtil.getSimpleName(n.getName)) + .getOrElse(uniqueName("composite_type", "", "")._1) + case _ => uniqueName("composite_type", "", "")._1 + s"${fullName(compType.getParent)}.$name" + case enumSpecifier: IASTEnumerationSpecifier => + s"${fullName(enumSpecifier.getParent)}.${ASTStringUtil.getSimpleName(enumSpecifier.getName)}" + case f: ICPPASTLambdaExpression => + s"${fullName(f.getParent)}." + case f: IASTFunctionDeclarator + if ASTStringUtil.getSimpleName(f.getName).isEmpty && f.getNestedDeclarator != null => + s"${fullName(f.getParent)}.${shortName(f.getNestedDeclarator)}" + case f: IASTFunctionDeclarator if f.getParent.isInstanceOf[IASTFunctionDefinition] => + s"${fullName(f.getParent)}" + case f: IASTFunctionDeclarator => + s"${fullName(f.getParent)}.${ASTStringUtil.getSimpleName(f.getName)}" + case f: IASTFunctionDefinition if f.getDeclarator != null => + s"${fullName(f.getParent)}.${ASTStringUtil.getQualifiedName(f.getDeclarator.getName)}" + case f: IASTFunctionDefinition => + s"${fullName(f.getParent)}.${shortName(f)}" + case e: IASTElaboratedTypeSpecifier => + s"${fullName(e.getParent)}.${ASTStringUtil.getSimpleName(e.getName)}" + case d: IASTIdExpression => ASTStringUtil.getSimpleName(d.getName) + case _: IASTTranslationUnit => "" + case u: IASTUnaryExpression => code(u.getOperand) + case x: ICPPASTQualifiedName => ASTStringUtil.getQualifiedName(x) + case other if other != null && other.getParent != null => fullName(other.getParent) + case other if other != null => notHandledYet(other); "" + case null => "" + fixQualifiedName(qualifiedName).stripPrefix(".") + end fullName + + protected def shortName(node: IASTNode): String = + val name = node match + case d: IASTDeclarator + if ASTStringUtil.getSimpleName( + d.getName + ).isEmpty && d.getNestedDeclarator != null => + shortName(d.getNestedDeclarator) + case d: IASTDeclarator => ASTStringUtil.getSimpleName(d.getName) + case f: ICPPASTFunctionDefinition + if ASTStringUtil + .getSimpleName(f.getDeclarator.getName) + .isEmpty && f.getDeclarator.getNestedDeclarator != null => + shortName(f.getDeclarator.getNestedDeclarator) + case f: ICPPASTFunctionDefinition => + lastNameOfQualifiedName(ASTStringUtil.getSimpleName(f.getDeclarator.getName)) + case f: IASTFunctionDefinition + if ASTStringUtil + .getSimpleName(f.getDeclarator.getName) + .isEmpty && f.getDeclarator.getNestedDeclarator != null => + shortName(f.getDeclarator.getNestedDeclarator) + case f: IASTFunctionDefinition => ASTStringUtil.getSimpleName(f.getDeclarator.getName) + case d: CPPASTIdExpression if d.getEvaluation.isInstanceOf[EvalBinding] => + val evaluation = d.getEvaluation.asInstanceOf[EvalBinding] + evaluation.getBinding match + case f: CPPFunction if f.getDeclarations != null => + f.getDeclarations.headOption.map(n => + ASTStringUtil.getSimpleName(n.getName) + ).getOrElse(f.getName) + case f: CPPFunction if f.getDefinition != null => + ASTStringUtil.getSimpleName(f.getDefinition.getName) + case other => + other.getName + case d: IASTIdExpression => + lastNameOfQualifiedName(ASTStringUtil.getSimpleName(d.getName)) + case u: IASTUnaryExpression => shortName(u.getOperand) + case c: IASTFunctionCallExpression => shortName(c.getFunctionNameExpression) + case s: IASTSimpleDeclSpecifier => s.getRawSignature + case e: IASTEnumerationSpecifier => ASTStringUtil.getSimpleName(e.getName) + case c: IASTCompositeTypeSpecifier => ASTStringUtil.getSimpleName(c.getName) + case e: IASTElaboratedTypeSpecifier => ASTStringUtil.getSimpleName(e.getName) + case s: IASTNamedTypeSpecifier => ASTStringUtil.getSimpleName(s.getName) + case other => notHandledYet(other); "" + name + end shortName + + protected def astsForDependenciesAndImports(iASTTranslationUnit: IASTTranslationUnit): Seq[Ast] = + val allIncludes = iASTTranslationUnit.getIncludeDirectives.toList.filterNot(isIncludedNode) + allIncludes.map { include => + val name = include.getName.toString + val _dependencyNode = newDependencyNode(name, name, IncludeKeyword) + val importNode = newImportNode(nodeSignature(include), name, name, include) + diffGraph.addNode(_dependencyNode) + diffGraph.addEdge(importNode, _dependencyNode, EdgeTypes.IMPORTS) + Ast(importNode) + } + + protected def isIncludedNode(node: IASTNode): Boolean = fileName(node) != filename + + protected def astsForComments(iASTTranslationUnit: IASTTranslationUnit): Seq[Ast] = + if config.includeComments then + iASTTranslationUnit.getComments.toList.filterNot(isIncludedNode).map(comment => + astForComment(comment) ) - - private def safeGetEvaluation(expr: ICPPASTExpression): Option[ICPPEvaluation] = - // In case of unresolved includes etc. this may fail throwing an unrecoverable exception - Try(expr.getEvaluation).toOption - - private def notHandledText(node: IASTNode): String = - s"""Node '${node.getClass.getSimpleName}' not handled yet! + else + Seq.empty + + protected def astForNode(node: IASTNode): Ast = + if config.includeFunctionBodies then astForNodeFull(node) else astForNodePartial(node) + + protected def astForNodeFull(node: IASTNode): Ast = + node match + case expr: IASTExpression => astForExpression(expr) + case name: IASTName => astForIdentifier(name) + case decl: IASTDeclSpecifier => astForIdentifier(decl) + case l: IASTInitializerList => astForInitializerList(l) + case c: ICPPASTConstructorInitializer => astForCPPASTConstructorInitializer(c) + case d: ICASTDesignatedInitializer => astForCASTDesignatedInitializer(d) + case d: ICPPASTDesignatedInitializer => astForCPPASTDesignatedInitializer(d) + case d: CASTArrayRangeDesignator => astForCASTArrayRangeDesignator(d) + case d: CPPASTArrayRangeDesignator => astForCPPASTArrayRangeDesignator(d) + case d: ICASTArrayDesignator => nullSafeAst(d.getSubscriptExpression) + case d: ICPPASTArrayDesignator => nullSafeAst(d.getSubscriptExpression) + case d: ICPPASTFieldDesignator => astForNode(d.getName) + case d: ICASTFieldDesignator => astForNode(d.getName) + case decl: ICPPASTDecltypeSpecifier => astForDecltypeSpecifier(decl) + case arrMod: IASTArrayModifier => astForArrayModifier(arrMod) + case _ => notHandledYet(node) + + protected def astForNodePartial(node: IASTNode): Ast = + node match + case expr: IASTExpression => astForExpression(expr) + case name: IASTName => astForIdentifier(name) + case decl: IASTDeclSpecifier => astForIdentifier(decl) + case decl: ICPPASTDecltypeSpecifier => astForDecltypeSpecifier(decl) + case _ => notHandledYet(node) + + protected def typeForDeclSpecifier( + spec: IASTNode, + stripKeywords: Boolean = true, + index: Int = 0 + ): String = + val tpe = spec match + case s: IASTSimpleDeclSpecifier if s.getParent.isInstanceOf[IASTParameterDeclaration] => + val parentDecl = s.getParent.asInstanceOf[IASTParameterDeclaration].getDeclarator + pointersAsString(s, parentDecl, stripKeywords) + case s: IASTSimpleDeclSpecifier if s.getParent.isInstanceOf[IASTFunctionDefinition] => + val parentDecl = s.getParent.asInstanceOf[IASTFunctionDefinition].getDeclarator + ASTStringUtil.getReturnTypeString(s, parentDecl) + case s: IASTSimpleDeclaration if s.getParent.isInstanceOf[ICASTKnRFunctionDeclarator] => + val decl = s.getDeclarators.toList(index) + pointersAsString(s.getDeclSpecifier, decl, stripKeywords) + case s: IASTSimpleDeclSpecifier if s.getParent.isInstanceOf[IASTSimpleDeclaration] => + val parentDecl = + s.getParent.asInstanceOf[IASTSimpleDeclaration].getDeclarators.toList(index) + pointersAsString(s, parentDecl, stripKeywords) + case s: IASTSimpleDeclSpecifier => + ASTStringUtil.getReturnTypeString(s, null) + case s: IASTNamedTypeSpecifier if s.getParent.isInstanceOf[IASTParameterDeclaration] => + val parentDecl = s.getParent.asInstanceOf[IASTParameterDeclaration].getDeclarator + pointersAsString(s, parentDecl, stripKeywords) + case s: IASTNamedTypeSpecifier if s.getParent.isInstanceOf[IASTSimpleDeclaration] => + val parentDecl = + s.getParent.asInstanceOf[IASTSimpleDeclaration].getDeclarators.toList(index) + pointersAsString(s, parentDecl, stripKeywords) + case s: IASTNamedTypeSpecifier => ASTStringUtil.getSimpleName(s.getName) + case s: IASTCompositeTypeSpecifier if s.getParent.isInstanceOf[IASTSimpleDeclaration] => + val parentDecl = + s.getParent.asInstanceOf[IASTSimpleDeclaration].getDeclarators.toList(index) + pointersAsString(s, parentDecl, stripKeywords) + case s: IASTCompositeTypeSpecifier => ASTStringUtil.getSimpleName(s.getName) + case s: IASTEnumerationSpecifier if s.getParent.isInstanceOf[IASTSimpleDeclaration] => + val parentDecl = + s.getParent.asInstanceOf[IASTSimpleDeclaration].getDeclarators.toList(index) + pointersAsString(s, parentDecl, stripKeywords) + case s: IASTEnumerationSpecifier => ASTStringUtil.getSimpleName(s.getName) + case s: IASTElaboratedTypeSpecifier + if s.getParent.isInstanceOf[IASTParameterDeclaration] => + val parentDecl = s.getParent.asInstanceOf[IASTParameterDeclaration].getDeclarator + pointersAsString(s, parentDecl, stripKeywords) + case s: IASTElaboratedTypeSpecifier + if s.getParent.isInstanceOf[IASTSimpleDeclaration] => + val parentDecl = + s.getParent.asInstanceOf[IASTSimpleDeclaration].getDeclarators.toList(index) + pointersAsString(s, parentDecl, stripKeywords) + case s: IASTElaboratedTypeSpecifier => ASTStringUtil.getSignatureString(s, null) + // TODO: handle other types of IASTDeclSpecifier + case _ => Defines.anyTypeName + if tpe.isEmpty then Defines.anyTypeName else tpe + end typeForDeclSpecifier + + private def safeGetEvaluation(expr: ICPPASTExpression): Option[ICPPEvaluation] = + // In case of unresolved includes etc. this may fail throwing an unrecoverable exception + Try(expr.getEvaluation).toOption + + protected def safeGetType(tpe: IType): String = + // In case of unresolved includes etc. this may fail throwing an unrecoverable exception + Try(ASTTypeUtil.getType(tpe)).getOrElse(Defines.anyTypeName) + + private def notHandledText(node: IASTNode): String = + s"""Node '${node.getClass.getSimpleName}' not handled yet! | Code: '${node.getRawSignature}' | File: '$filename' | Line: ${line(node).getOrElse(-1)} | """.stripMargin - private def pointersAsString( - spec: IASTDeclSpecifier, - parentDecl: IASTDeclarator, - stripKeywords: Boolean - ): String = - val tpe = typeFor(spec, stripKeywords) - val pointers = parentDecl.getPointerOperators - val arr = parentDecl match - case p: IASTArrayDeclarator => - p.getArrayModifiers.toList.map(_.getRawSignature).mkString - case _ => "" - if pointers.isEmpty then s"$tpe$arr" - else - val refs = - "*" * (pointers.length - pointers.count(_.isInstanceOf[ICPPASTReferenceOperator])) - s"$tpe$arr$refs".strip() - - private def astForDecltypeSpecifier(decl: ICPPASTDecltypeSpecifier): Ast = - val op = ".typeOf" - val cpgUnary = callNode(decl, nodeSignature(decl), op, op, DispatchTypes.STATIC_DISPATCH) - val operand = nullSafeAst(decl.getDecltypeExpression) - callAst(cpgUnary, List(operand)) - - private def astForCASTDesignatedInitializer(d: ICASTDesignatedInitializer): Ast = - val node = blockNode(d, Defines.empty, Defines.voidTypeName) - scope.pushNewScope(node) - val op = Operators.assignment - val calls = withIndex(d.getDesignators) { (des, o) => - val callNode_ = - callNode(d, nodeSignature(d), op, op, DispatchTypes.STATIC_DISPATCH) - .argumentIndex(o) - val left = astForNode(des) - val right = astForNode(d.getOperand) - callAst(callNode_, List(left, right)) - } - scope.popScope() - blockAst(node, calls.toList) - - private def astForCPPASTDesignatedInitializer(d: ICPPASTDesignatedInitializer): Ast = - val node = blockNode(d, Defines.empty, Defines.voidTypeName) - scope.pushNewScope(node) - val op = Operators.assignment - val calls = withIndex(d.getDesignators) { (des, o) => - val callNode_ = - callNode(d, nodeSignature(d), op, op, DispatchTypes.STATIC_DISPATCH) - .argumentIndex(o) - val left = astForNode(des) - val right = astForNode(d.getOperand) - callAst(callNode_, List(left, right)) - } - scope.popScope() - blockAst(node, calls.toList) - - private def astForCPPASTConstructorInitializer(c: ICPPASTConstructorInitializer): Ast = - val name = ".constructorInitializer" - val callNode_ = - callNode(c, nodeSignature(c), name, name, DispatchTypes.STATIC_DISPATCH) - val args = c.getArguments.toList.map(a => astForNode(a)) - callAst(callNode_, args) - - private def astForCASTArrayRangeDesignator(des: CASTArrayRangeDesignator): Ast = - val op = Operators.arrayInitializer - val callNode_ = callNode(des, nodeSignature(des), op, op, DispatchTypes.STATIC_DISPATCH) - val floorAst = nullSafeAst(des.getRangeFloor) - val ceilingAst = nullSafeAst(des.getRangeCeiling) - callAst(callNode_, List(floorAst, ceilingAst)) - - private def astForCPPASTArrayRangeDesignator(des: CPPASTArrayRangeDesignator): Ast = - val op = Operators.arrayInitializer - val callNode_ = callNode(des, nodeSignature(des), op, op, DispatchTypes.STATIC_DISPATCH) - val floorAst = nullSafeAst(des.getRangeFloor) - val ceilingAst = nullSafeAst(des.getRangeCeiling) - callAst(callNode_, List(floorAst, ceilingAst)) + private def pointersAsString( + spec: IASTDeclSpecifier, + parentDecl: IASTDeclarator, + stripKeywords: Boolean + ): String = + val tpe = typeFor(spec, stripKeywords) + val pointers = parentDecl.getPointerOperators + val arr = parentDecl match + case p: IASTArrayDeclarator => + p.getArrayModifiers.toList.map(_.getRawSignature).mkString + case _ => "" + if pointers.isEmpty then s"$tpe$arr" + else + val refs = + "*" * (pointers.length - pointers.count(_.isInstanceOf[ICPPASTReferenceOperator])) + s"$tpe$arr$refs".strip() + + private def astForDecltypeSpecifier(decl: ICPPASTDecltypeSpecifier): Ast = + val op = ".typeOf" + val cpgUnary = callNode(decl, nodeSignature(decl), op, op, DispatchTypes.STATIC_DISPATCH) + val operand = nullSafeAst(decl.getDecltypeExpression) + callAst(cpgUnary, List(operand)) + + private def astForCASTDesignatedInitializer(d: ICASTDesignatedInitializer): Ast = + val node = blockNode(d, Defines.empty, Defines.voidTypeName) + scope.pushNewScope(node) + val op = Operators.assignment + val calls = withIndex(d.getDesignators) { (des, o) => + val callNode_ = + callNode(d, nodeSignature(d), op, op, DispatchTypes.STATIC_DISPATCH) + .argumentIndex(o) + val left = astForNode(des) + val right = astForNode(d.getOperand) + callAst(callNode_, List(left, right)) + } + scope.popScope() + blockAst(node, calls.toList) + + private def astForCPPASTDesignatedInitializer(d: ICPPASTDesignatedInitializer): Ast = + val node = blockNode(d, Defines.empty, Defines.voidTypeName) + scope.pushNewScope(node) + val op = Operators.assignment + val calls = withIndex(d.getDesignators) { (des, o) => + val callNode_ = + callNode(d, nodeSignature(d), op, op, DispatchTypes.STATIC_DISPATCH) + .argumentIndex(o) + val left = astForNode(des) + val right = astForNode(d.getOperand) + callAst(callNode_, List(left, right)) + } + scope.popScope() + blockAst(node, calls.toList) + + private def astForCPPASTConstructorInitializer(c: ICPPASTConstructorInitializer): Ast = + val name = ".constructorInitializer" + val callNode_ = + callNode(c, nodeSignature(c), name, name, DispatchTypes.STATIC_DISPATCH) + val args = c.getArguments.toList.map(a => astForNode(a)) + callAst(callNode_, args) + + private def astForCASTArrayRangeDesignator(des: CASTArrayRangeDesignator): Ast = + val op = Operators.arrayInitializer + val callNode_ = callNode(des, nodeSignature(des), op, op, DispatchTypes.STATIC_DISPATCH) + val floorAst = nullSafeAst(des.getRangeFloor) + val ceilingAst = nullSafeAst(des.getRangeCeiling) + callAst(callNode_, List(floorAst, ceilingAst)) + + private def astForCPPASTArrayRangeDesignator(des: CPPASTArrayRangeDesignator): Ast = + val op = Operators.arrayInitializer + val callNode_ = callNode(des, nodeSignature(des), op, op, DispatchTypes.STATIC_DISPATCH) + val floorAst = nullSafeAst(des.getRangeFloor) + val ceilingAst = nullSafeAst(des.getRangeCeiling) + callAst(callNode_, List(floorAst, ceilingAst)) end AstCreatorHelper diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForExpressionsCreator.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForExpressionsCreator.scala index 9ae0790d..0f5e7939 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForExpressionsCreator.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForExpressionsCreator.scala @@ -3,377 +3,548 @@ package io.appthreat.c2cpg.astcreation import io.shiftleft.codepropertygraph.generated.nodes.{NewCall, NewIdentifier, NewMethodRef} import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} import io.appthreat.x2cpg.{Ast, ValidationMode} +import io.appthreat.x2cpg.Defines as X2CpgDefines +import org.eclipse.cdt.core.dom.ast import org.eclipse.cdt.core.dom.ast.* import org.eclipse.cdt.core.dom.ast.cpp.* import org.eclipse.cdt.core.dom.ast.gnu.IGNUASTCompoundStatementExpression -import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPASTQualifiedName +import org.eclipse.cdt.core.model.IMethod +import org.eclipse.cdt.internal.core.dom.parser.c.{ + CASTFieldReference, + CASTFunctionCallExpression, + CASTIdExpression, + CBasicType, + CFunctionType, + CPointerType +} +import org.eclipse.cdt.internal.core.dom.parser.cpp.semantics.{EvalBinding, EvalFunctionCall} +import org.eclipse.cdt.internal.core.dom.parser.cpp.{ + CPPASTIdExpression, + CPPASTQualifiedName, + CPPClosureType, + CPPField, + CPPFunction, + CPPFunctionType +} trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode): - this: AstCreator => - - private def astForBinaryExpression(bin: IASTBinaryExpression): Ast = - val op = bin.getOperator match - case IASTBinaryExpression.op_multiply => Operators.multiplication - case IASTBinaryExpression.op_divide => Operators.division - case IASTBinaryExpression.op_modulo => Operators.modulo - case IASTBinaryExpression.op_plus => Operators.addition - case IASTBinaryExpression.op_minus => Operators.subtraction - case IASTBinaryExpression.op_shiftLeft => Operators.shiftLeft - case IASTBinaryExpression.op_shiftRight => Operators.arithmeticShiftRight - case IASTBinaryExpression.op_lessThan => Operators.lessThan - case IASTBinaryExpression.op_greaterThan => Operators.greaterThan - case IASTBinaryExpression.op_lessEqual => Operators.lessEqualsThan - case IASTBinaryExpression.op_greaterEqual => Operators.greaterEqualsThan - case IASTBinaryExpression.op_binaryAnd => Operators.and - case IASTBinaryExpression.op_binaryXor => Operators.xor - case IASTBinaryExpression.op_binaryOr => Operators.or - case IASTBinaryExpression.op_logicalAnd => Operators.logicalAnd - case IASTBinaryExpression.op_logicalOr => Operators.logicalOr - case IASTBinaryExpression.op_assign => Operators.assignment - case IASTBinaryExpression.op_multiplyAssign => Operators.assignmentMultiplication - case IASTBinaryExpression.op_divideAssign => Operators.assignmentDivision - case IASTBinaryExpression.op_moduloAssign => Operators.assignmentModulo - case IASTBinaryExpression.op_plusAssign => Operators.assignmentPlus - case IASTBinaryExpression.op_minusAssign => Operators.assignmentMinus - case IASTBinaryExpression.op_shiftLeftAssign => Operators.assignmentShiftLeft - case IASTBinaryExpression.op_shiftRightAssign => - Operators.assignmentArithmeticShiftRight - case IASTBinaryExpression.op_binaryAndAssign => Operators.assignmentAnd - case IASTBinaryExpression.op_binaryXorAssign => Operators.assignmentXor - case IASTBinaryExpression.op_binaryOrAssign => Operators.assignmentOr - case IASTBinaryExpression.op_equals => Operators.equals - case IASTBinaryExpression.op_notequals => Operators.notEquals - case IASTBinaryExpression.op_pmdot => Operators.indirectFieldAccess - case IASTBinaryExpression.op_pmarrow => Operators.indirectFieldAccess - case IASTBinaryExpression.op_max => ".max" - case IASTBinaryExpression.op_min => ".min" - case IASTBinaryExpression.op_ellipses => ".op_ellipses" - case _ => ".unknown" - - val callNode_ = callNode( - bin, - nodeSignature(bin), - op, - op, - if op == Operators.indirectFieldAccess then DispatchTypes.DYNAMIC_DISPATCH - else DispatchTypes.STATIC_DISPATCH - ) - val left = nullSafeAst(bin.getOperand1) - val right = nullSafeAst(bin.getOperand2) - callAst(callNode_, List(left, right)) - end astForBinaryExpression - - private def astForExpressionList(exprList: IASTExpressionList): Ast = - val name = ".expressionList" - val callNode_ = - callNode(exprList, nodeSignature(exprList), name, name, DispatchTypes.STATIC_DISPATCH) - val childAsts = exprList.getExpressions.map(nullSafeAst) - callAst(callNode_, childAsts.toIndexedSeq) - - private def astForCallExpression(call: IASTFunctionCallExpression): Ast = - val rec = call.getFunctionNameExpression match - case unaryExpression: IASTUnaryExpression - if unaryExpression.getOperand.isInstanceOf[IASTBinaryExpression] => - astForBinaryExpression( - unaryExpression.getOperand.asInstanceOf[IASTBinaryExpression] + this: AstCreator => + + protected def astForExpression(expression: IASTExpression): Ast = + val r = expression match + case lit: IASTLiteralExpression => astForLiteral(lit) + case un: IASTUnaryExpression => astForUnaryExpression(un) + case bin: IASTBinaryExpression => astForBinaryExpression(bin) + case exprList: IASTExpressionList => astForExpressionList(exprList) + case idExpr: IASTIdExpression => astForIdExpression(idExpr) + case call: IASTFunctionCallExpression => astForCallExpression(call) + case typeId: IASTTypeIdExpression => astForTypeIdExpression(typeId) + case fieldRef: IASTFieldReference => astForFieldReference(fieldRef) + case expr: IASTConditionalExpression => astForConditionalExpression(expr) + case arr: IASTArraySubscriptExpression => astForArrayIndexExpression(arr) + case castExpression: IASTCastExpression => astForCastExpression(castExpression) + case newExpression: ICPPASTNewExpression => astForNewExpression(newExpression) + case delExpression: ICPPASTDeleteExpression => astForDeleteExpression(delExpression) + case typeIdInit: IASTTypeIdInitializerExpression => astForTypeIdInitExpression(typeIdInit) + case c: ICPPASTSimpleTypeConstructorExpression => astForConstructorExpression(c) + case lambdaExpression: ICPPASTLambdaExpression => astForMethodRefForLambda(lambdaExpression) + case cExpr: IGNUASTCompoundStatementExpression => astForCompoundStatementExpression(cExpr) + case pExpr: ICPPASTPackExpansionExpression => astForPackExpansionExpression(pExpr) + case _ => notHandledYet(expression) + asChildOfMacroCall(expression, r) + end astForExpression + + protected def astForStaticAssert(a: ICPPASTStaticAssertDeclaration): Ast = + val name = "static_assert" + val call = callNode(a, code(a), name, name, DispatchTypes.STATIC_DISPATCH) + val cond = nullSafeAst(a.getCondition) + val messg = nullSafeAst(a.getMessage) + callAst(call, List(cond, messg)) + + private def astForBinaryExpression(bin: IASTBinaryExpression): Ast = + val op = bin.getOperator match + case IASTBinaryExpression.op_multiply => Operators.multiplication + case IASTBinaryExpression.op_divide => Operators.division + case IASTBinaryExpression.op_modulo => Operators.modulo + case IASTBinaryExpression.op_plus => Operators.addition + case IASTBinaryExpression.op_minus => Operators.subtraction + case IASTBinaryExpression.op_shiftLeft => Operators.shiftLeft + case IASTBinaryExpression.op_shiftRight => Operators.arithmeticShiftRight + case IASTBinaryExpression.op_lessThan => Operators.lessThan + case IASTBinaryExpression.op_greaterThan => Operators.greaterThan + case IASTBinaryExpression.op_lessEqual => Operators.lessEqualsThan + case IASTBinaryExpression.op_greaterEqual => Operators.greaterEqualsThan + case IASTBinaryExpression.op_binaryAnd => Operators.and + case IASTBinaryExpression.op_binaryXor => Operators.xor + case IASTBinaryExpression.op_binaryOr => Operators.or + case IASTBinaryExpression.op_logicalAnd => Operators.logicalAnd + case IASTBinaryExpression.op_logicalOr => Operators.logicalOr + case IASTBinaryExpression.op_assign => Operators.assignment + case IASTBinaryExpression.op_multiplyAssign => Operators.assignmentMultiplication + case IASTBinaryExpression.op_divideAssign => Operators.assignmentDivision + case IASTBinaryExpression.op_moduloAssign => Operators.assignmentModulo + case IASTBinaryExpression.op_plusAssign => Operators.assignmentPlus + case IASTBinaryExpression.op_minusAssign => Operators.assignmentMinus + case IASTBinaryExpression.op_shiftLeftAssign => Operators.assignmentShiftLeft + case IASTBinaryExpression.op_shiftRightAssign => Operators.assignmentArithmeticShiftRight + case IASTBinaryExpression.op_binaryAndAssign => Operators.assignmentAnd + case IASTBinaryExpression.op_binaryXorAssign => Operators.assignmentXor + case IASTBinaryExpression.op_binaryOrAssign => Operators.assignmentOr + case IASTBinaryExpression.op_equals => Operators.equals + case IASTBinaryExpression.op_notequals => Operators.notEquals + case IASTBinaryExpression.op_pmdot => Operators.indirectFieldAccess + case IASTBinaryExpression.op_pmarrow => Operators.indirectFieldAccess + case IASTBinaryExpression.op_max => ".max" + case IASTBinaryExpression.op_min => ".min" + case IASTBinaryExpression.op_ellipses => ".op_ellipses" + case _ => ".unknown" + + val callNode_ = callNode(bin, code(bin), op, op, DispatchTypes.STATIC_DISPATCH) + val left = nullSafeAst(bin.getOperand1) + val right = nullSafeAst(bin.getOperand2) + callAst(callNode_, List(left, right)) + end astForBinaryExpression + + private def astForExpressionList(exprList: IASTExpressionList): Ast = + val name = ".expressionList" + val callNode_ = + callNode(exprList, code(exprList), name, name, DispatchTypes.STATIC_DISPATCH) + val childAsts = exprList.getExpressions.map(nullSafeAst) + callAst(callNode_, childAsts.toIndexedSeq) + + private def astForCppCallExpression(call: ICPPASTFunctionCallExpression): Ast = + val functionNameExpr = call.getFunctionNameExpression + val typ = functionNameExpr.getExpressionType + typ match + case pointerType: IPointerType => + createPointerCallAst(call, cleanType(ASTTypeUtil.getType(call.getExpressionType))) + case functionType: ICPPFunctionType => + functionNameExpr match + case idExpr: CPPASTIdExpression => + val function = idExpr.getName.getBinding.asInstanceOf[ICPPFunction] + val name = idExpr.getName.getLastName.toString + val signature = + if function.isExternC then + "" + else + functionTypeToSignature(functionType) + + val fullName = + if function.isExternC then + name + else + val fullNameNoSig = function.getQualifiedName.mkString(".") + s"$fullNameNoSig:$signature" + + val dispatchType = DispatchTypes.STATIC_DISPATCH + + val callCpgNode = callNode( + call, + code(call), + name, + fullName, + dispatchType, + Some(signature), + Some(registerType(cleanType(safeGetType(call.getExpressionType)))) + ) + val args = call.getArguments.toList.map(a => astForNode(a)) + + createCallAst(callCpgNode, args) + case fieldRefExpr: ICPPASTFieldReference => + val instanceAst = astForExpression(fieldRefExpr.getFieldOwner) + val args = call.getArguments.toList.map(a => astForNode(a)) + + // TODO This wont do if the name is a reference. + val name = fieldRefExpr.getFieldName.toString + val signature = functionTypeToSignature(functionType) + + val classFullName = cleanType(ASTTypeUtil.getType(fieldRefExpr.getFieldOwnerType)) + val fullName = s"$classFullName.$name:$signature" + + fieldRefExpr.getFieldName.resolveBinding() + val method = fieldRefExpr.getFieldName.getBinding().asInstanceOf[ICPPMethod] + val (dispatchType, receiver) = + if method.isVirtual || method.isPureVirtual then + (DispatchTypes.DYNAMIC_DISPATCH, Some(instanceAst)) + else + (DispatchTypes.STATIC_DISPATCH, None) + val callCpgNode = callNode( + call, + code(call), + name, + fullName, + dispatchType, + Some(signature), + Some(cleanType(ASTTypeUtil.getType(call.getExpressionType))) ) - case unaryExpression: IASTUnaryExpression - if unaryExpression.getOperand.isInstanceOf[IASTFieldReference] => - astForFieldReference(unaryExpression.getOperand.asInstanceOf[IASTFieldReference]) - case unaryExpression: IASTUnaryExpression - if unaryExpression.getOperand.isInstanceOf[IASTArraySubscriptExpression] => - astForArrayIndexExpression( - unaryExpression.getOperand.asInstanceOf[IASTArraySubscriptExpression] + + createCallAst(callCpgNode, args, base = Some(instanceAst), receiver) + case classType: ICPPClassType => + val evaluation = call.getEvaluation.asInstanceOf[EvalFunctionCall] + val functionType = evaluation.getOverload.getType + val signature = functionTypeToSignature(functionType) + val name = "()" + + classType match + case closureType: CPPClosureType => + val fullName = s"$name:$signature" + val dispatchType = DispatchTypes.DYNAMIC_DISPATCH + + val callCpgNode = callNode( + call, + code(call), + name, + fullName, + dispatchType, + Some(signature), + Some(cleanType(ASTTypeUtil.getType(call.getExpressionType))) + ) + + val receiverAst = astForExpression(functionNameExpr) + val args = call.getArguments.toList.map(a => astForNode(a)) + + createCallAst(callCpgNode, args, receiver = Some(receiverAst)) + case _ => + val classFullName = cleanType(ASTTypeUtil.getType(classType)) + val fullName = s"$classFullName.$name:$signature" + + val method = evaluation.getOverload.asInstanceOf[ICPPMethod] + val dispatchType = + if method.isVirtual || method.isPureVirtual then + DispatchTypes.DYNAMIC_DISPATCH + else + DispatchTypes.STATIC_DISPATCH + + val callCpgNode = callNode( + call, + code(call), + name, + fullName, + dispatchType, + Some(signature), + Some(cleanType(ASTTypeUtil.getType(call.getExpressionType))) ) - case unaryExpression: IASTUnaryExpression - if unaryExpression.getOperand.isInstanceOf[IASTConditionalExpression] => - astForUnaryExpression(unaryExpression) - case unaryExpression: IASTUnaryExpression - if unaryExpression.getOperand.isInstanceOf[IASTUnaryExpression] => - astForUnaryExpression(unaryExpression.getOperand.asInstanceOf[IASTUnaryExpression]) - case lambdaExpression: ICPPASTLambdaExpression => - astForMethodRefForLambda(lambdaExpression) - case other => astForExpression(other) - - val (dd, name) = call.getFunctionNameExpression match - case _: ICPPASTLambdaExpression => - ( - DispatchTypes.DYNAMIC_DISPATCH, - rec.root.get.asInstanceOf[NewMethodRef].methodFullName + + val instanceAst = astForExpression(functionNameExpr) + val args = call.getArguments.toList.map(a => astForNode(a)) + + createCallAst( + callCpgNode, + args, + base = Some(instanceAst), + receiver = Some(instanceAst) ) - case _ if rec.root.exists(_.isInstanceOf[NewIdentifier]) => - (DispatchTypes.STATIC_DISPATCH, rec.root.get.asInstanceOf[NewIdentifier].name) - case _ - if rec.root.exists(_.isInstanceOf[NewCall]) && call.getFunctionNameExpression - .isInstanceOf[IASTFieldReference] => - ( - DispatchTypes.STATIC_DISPATCH, - nodeSignature( - call.getFunctionNameExpression.asInstanceOf[IASTFieldReference].getFieldName - ) + end match + case _: IProblemType => + astForCppCallExpressionUntyped(call) + case _: IProblemBinding => + astForCppCallExpressionUntyped(call) + end match + end astForCppCallExpression + + private def astForCppCallExpressionUntyped(call: ICPPASTFunctionCallExpression): Ast = + val functionNameExpr = call.getFunctionNameExpression + + functionNameExpr match + case fieldRefExpr: ICPPASTFieldReference => + val instanceAst = astForExpression(fieldRefExpr.getFieldOwner) + val args = call.getArguments.toList.map(a => astForNode(a)) + + val name = fieldRefExpr.getFieldName.toString + val signature = X2CpgDefines.UnresolvedSignature + val fullName = s"${X2CpgDefines.UnresolvedNamespace}.$name:$signature(${args.size})" + + val callCpgNode = callNode( + call, + code(call), + name, + fullName, + DispatchTypes.STATIC_DISPATCH, + Some(signature), + Some(X2CpgDefines.Any) + ) + + createCallAst(callCpgNode, args, base = Some(instanceAst), receiver = Some(instanceAst)) + case idExpr: CPPASTIdExpression => + val args = call.getArguments.toList.map(a => astForNode(a)) + + val name = idExpr.getName.getLastName.toString + val signature = X2CpgDefines.UnresolvedSignature + val fullName = s"${X2CpgDefines.UnresolvedNamespace}.$name:$signature(${args.size})" + + val callCpgNode = callNode( + call, + code(call), + name, + fullName, + DispatchTypes.STATIC_DISPATCH, + Some(signature), + Some(X2CpgDefines.Any) + ) + + createCallAst(callCpgNode, args) + case other => + // This could either be a pointer or an operator() call we dont know at this point + // but since it is CPP we opt for the later. + val args = call.getArguments.toList.map(a => astForNode(a)) + + val name = "()" + val signature = X2CpgDefines.UnresolvedSignature + val fullName = s"${X2CpgDefines.UnresolvedNamespace}.$name:$signature(${args.size})" + + val callCpgNode = callNode( + call, + code(call), + name, + fullName, + DispatchTypes.STATIC_DISPATCH, + Some(signature), + Some(X2CpgDefines.Any) + ) + + val instanceAst = astForExpression(functionNameExpr) + createCallAst(callCpgNode, args, base = Some(instanceAst), receiver = Some(instanceAst)) + end match + end astForCppCallExpressionUntyped + + private def astForCCallExpression(call: CASTFunctionCallExpression): Ast = + val functionNameExpr = call.getFunctionNameExpression + val typ = functionNameExpr.getExpressionType + typ match + case pointerType: CPointerType => + createPointerCallAst(call, cleanType(ASTTypeUtil.getType(call.getExpressionType))) + case functionType: CFunctionType => + functionNameExpr match + case idExpr: CASTIdExpression => + createCFunctionCallAst( + call, + idExpr, + cleanType(ASTTypeUtil.getType(call.getExpressionType)) ) - case _ if rec.root.exists(_.isInstanceOf[NewCall]) => - (DispatchTypes.STATIC_DISPATCH, rec.root.get.asInstanceOf[NewCall].code) - case reference: IASTIdExpression => - (DispatchTypes.STATIC_DISPATCH, nodeSignature(reference)) case _ => - (DispatchTypes.STATIC_DISPATCH, "") - - val shortName = fixQualifiedName(name) - val fullName = typeFor(call.getFunctionNameExpression) match - case t if t != Defines.anyTypeName => s"${dereferenceTypeFullName(t)}.$shortName" - case _ => shortName - val cpgCall = callNode(call, nodeSignature(call), shortName, fullName, dd) - val args = call.getArguments.toList.map(a => astForNode(a)) - rec.root match - // Optimization: do not include the receiver if the receiver is just the function name, - // e.g., for `f(x)`, don't include an `f` identifier node as a first child. Since we - // have so many call sites in CPGs, this drastically reduces the number of nodes. - // Moreover, the data flow tracker does not need to track `f`, which would not make - // much sense anyway. - case Some(r: NewIdentifier) if r.name == shortName => - callAst(cpgCall, args) - case Some(r: NewMethodRef) if r.code == shortName => - callAst(cpgCall, args) - case Some(_) => - callAst(cpgCall, args, Option(rec)) - case None => - callAst(cpgCall, args) - end astForCallExpression - - private def astForUnaryExpression(unary: IASTUnaryExpression): Ast = - val operatorMethod = unary.getOperator match - case IASTUnaryExpression.op_prefixIncr => Operators.preIncrement - case IASTUnaryExpression.op_prefixDecr => Operators.preDecrement - case IASTUnaryExpression.op_plus => Operators.plus - case IASTUnaryExpression.op_minus => Operators.minus - case IASTUnaryExpression.op_star => Operators.indirection - case IASTUnaryExpression.op_amper => Operators.addressOf - case IASTUnaryExpression.op_tilde => Operators.not - case IASTUnaryExpression.op_not => Operators.logicalNot - case IASTUnaryExpression.op_sizeof => Operators.sizeOf - case IASTUnaryExpression.op_postFixIncr => Operators.postIncrement - case IASTUnaryExpression.op_postFixDecr => Operators.postDecrement - case IASTUnaryExpression.op_throw => ".throw" - case IASTUnaryExpression.op_typeid => ".typeOf" - case IASTUnaryExpression.op_bracketedPrimary => ".bracketedPrimary" - case _ => ".unknown" - - if - unary.getOperator == IASTUnaryExpression.op_bracketedPrimary && - !unary.getOperand.isInstanceOf[IASTExpressionList] - then - nullSafeAst(unary.getOperand) - else - val cpgUnary = + createPointerCallAst(call, cleanType(ASTTypeUtil.getType(call.getExpressionType))) + case _ => + astForCCallExpressionUntyped(call) + + private def createCFunctionCallAst( + call: CASTFunctionCallExpression, + idExpr: CASTIdExpression, + callTypeFullName: String + ): Ast = + val name = idExpr.getName.getLastName.toString + val signature = "" + + val dispatchType = DispatchTypes.STATIC_DISPATCH + + val callCpgNode = callNode( + call, + code(call), + name, + name, + dispatchType, + Some(signature), + Some(callTypeFullName) + ) + val args = call.getArguments.toList.map(a => astForNode(a)) + + createCallAst(callCpgNode, args) + end createCFunctionCallAst + + private def createPointerCallAst( + call: IASTFunctionCallExpression, + callTypeFullName: String + ): Ast = + val functionNameExpr = call.getFunctionNameExpression + val name = Defines.operatorPointerCall + val signature = "" + + val callCpgNode = + callNode( + call, + code(call), + name, + name, + DispatchTypes.DYNAMIC_DISPATCH, + Some(signature), + Some(callTypeFullName) + ) + + val args = call.getArguments.toList.map(a => astForNode(a)) + val receiverAst = astForExpression(functionNameExpr) + createCallAst(callCpgNode, args, receiver = Some(receiverAst)) + end createPointerCallAst + + private def astForCCallExpressionUntyped(call: CASTFunctionCallExpression): Ast = + val functionNameExpr = call.getFunctionNameExpression + + functionNameExpr match + case idExpr: CASTIdExpression => + createCFunctionCallAst(call, idExpr, X2CpgDefines.Any) + case _ => + createPointerCallAst(call, X2CpgDefines.Any) + + private def astForCallExpression(call: IASTFunctionCallExpression): Ast = + call match + case cppCall: ICPPASTFunctionCallExpression => + astForCppCallExpression(cppCall) + case cCall: CASTFunctionCallExpression => + astForCCallExpression(cCall) + + private def astForUnaryExpression(unary: IASTUnaryExpression): Ast = + val operatorMethod = unary.getOperator match + case IASTUnaryExpression.op_prefixIncr => Operators.preIncrement + case IASTUnaryExpression.op_prefixDecr => Operators.preDecrement + case IASTUnaryExpression.op_plus => Operators.plus + case IASTUnaryExpression.op_minus => Operators.minus + case IASTUnaryExpression.op_star => Operators.indirection + case IASTUnaryExpression.op_amper => Operators.addressOf + case IASTUnaryExpression.op_tilde => Operators.not + case IASTUnaryExpression.op_not => Operators.logicalNot + case IASTUnaryExpression.op_sizeof => Operators.sizeOf + case IASTUnaryExpression.op_postFixIncr => Operators.postIncrement + case IASTUnaryExpression.op_postFixDecr => Operators.postDecrement + case IASTUnaryExpression.op_throw => ".throw" + case IASTUnaryExpression.op_typeid => ".typeOf" + case IASTUnaryExpression.op_bracketedPrimary => ".bracketedPrimary" + case _ => ".unknown" + + if + unary.getOperator == IASTUnaryExpression.op_bracketedPrimary && + !unary.getOperand.isInstanceOf[IASTExpressionList] + then + nullSafeAst(unary.getOperand) + else + val cpgUnary = + callNode( + unary, + code(unary), + operatorMethod, + operatorMethod, + DispatchTypes.STATIC_DISPATCH + ) + val operand = nullSafeAst(unary.getOperand) + callAst(cpgUnary, List(operand)) + end astForUnaryExpression + + private def astForTypeIdExpression(typeId: IASTTypeIdExpression): Ast = + typeId.getOperator match + case op + if op == IASTTypeIdExpression.op_sizeof || + op == IASTTypeIdExpression.op_sizeofParameterPack || + op == IASTTypeIdExpression.op_typeid || + op == IASTTypeIdExpression.op_alignof || + op == IASTTypeIdExpression.op_typeof => + val call = callNode( - unary, - nodeSignature(unary), - operatorMethod, - operatorMethod, - if operatorMethod == Operators.addressOf || operatorMethod == Operators.indirectFieldAccess - then DispatchTypes.DYNAMIC_DISPATCH - else DispatchTypes.STATIC_DISPATCH + typeId, + code(typeId), + Operators.sizeOf, + Operators.sizeOf, + DispatchTypes.STATIC_DISPATCH ) - val operand = nullSafeAst(unary.getOperand) - callAst(cpgUnary, List(operand)) - end astForUnaryExpression - - private def astForTypeIdExpression(typeId: IASTTypeIdExpression): Ast = - typeId.getOperator match - case op - if op == IASTTypeIdExpression.op_sizeof || - op == IASTTypeIdExpression.op_sizeofParameterPack || - op == IASTTypeIdExpression.op_typeid || - op == IASTTypeIdExpression.op_alignof || - op == IASTTypeIdExpression.op_typeof => - val call = - callNode( - typeId, - nodeSignature(typeId), - Operators.sizeOf, - Operators.sizeOf, - DispatchTypes.STATIC_DISPATCH - ) - val arg = astForNode(typeId.getTypeId.getDeclSpecifier) - callAst(call, List(arg)) - case _ => notHandledYet(typeId) - - private def astForConditionalExpression(expr: IASTConditionalExpression): Ast = - val name = Operators.conditional - val call = callNode(expr, nodeSignature(expr), name, name, DispatchTypes.STATIC_DISPATCH) - - val condAst = nullSafeAst(expr.getLogicalConditionExpression) - val posAst = nullSafeAst(expr.getPositiveResultExpression) - val negAst = nullSafeAst(expr.getNegativeResultExpression) - - val children = List(condAst, posAst, negAst) - callAst(call, children) - - private def astForArrayIndexExpression(arrayIndexExpression: IASTArraySubscriptExpression) - : Ast = - val name = Operators.indirectIndexAccess - val cpgArrayIndexing = - callNode( - arrayIndexExpression, - nodeSignature(arrayIndexExpression), - name, - name, - DispatchTypes.STATIC_DISPATCH - ) - - val expr = astForExpression(arrayIndexExpression.getArrayExpression) - val arg = astForNode(arrayIndexExpression.getArgument) - callAst(cpgArrayIndexing, List(expr, arg)) - - private def astForCastExpression(castExpression: IASTCastExpression): Ast = - val cpgCastExpression = - callNode( - castExpression, - nodeSignature(castExpression), - Operators.cast, - Operators.cast, - DispatchTypes.STATIC_DISPATCH - ) - - val expr = astForExpression(castExpression.getOperand) - val argNode = castExpression.getTypeId - val arg = unknownNode(argNode, nodeSignature(argNode)) - - callAst(cpgCastExpression, List(Ast(arg), expr)) - - private def astsForConstructorInitializer(initializer: IASTInitializer): List[Ast] = - initializer match - case init: ICPPASTConstructorInitializer => - init.getArguments.toList.map(x => astForNode(x)) - case _ => Nil // null or unexpected type - - private def astsForInitializerPlacements(initializerPlacements: Array[IASTInitializerClause]) - : List[Ast] = - if initializerPlacements != null then initializerPlacements.toList.map(x => astForNode(x)) - else Nil - - private def astForNewExpression(newExpression: ICPPASTNewExpression): Ast = - val name = ".new" - val cpgNewExpression = - callNode( - newExpression, - nodeSignature(newExpression), - name, - name, - DispatchTypes.STATIC_DISPATCH - ) - - val typeId = newExpression.getTypeId - if newExpression.isArrayAllocation then - val cpgTypeId = astForIdentifier(typeId.getDeclSpecifier) - Ast(cpgNewExpression).withChild(cpgTypeId).withArgEdge( - cpgNewExpression, - cpgTypeId.root.get - ) - else - val cpgTypeId = astForIdentifier(typeId.getDeclSpecifier) - val args = astsForConstructorInitializer(newExpression.getInitializer) ++ - astsForInitializerPlacements(newExpression.getPlacementArguments) - callAst(cpgNewExpression, List(cpgTypeId) ++ args) - end astForNewExpression - - private def astForDeleteExpression(delExpression: ICPPASTDeleteExpression): Ast = - val name = Operators.delete - val cpgDeleteNode = - callNode( - delExpression, - nodeSignature(delExpression), - name, - name, - DispatchTypes.STATIC_DISPATCH - ) - val arg = astForExpression(delExpression.getOperand) - callAst(cpgDeleteNode, List(arg)) - - private def astForTypeIdInitExpression(typeIdInit: IASTTypeIdInitializerExpression): Ast = - val name = Operators.cast - val cpgCastExpression = - callNode( - typeIdInit, - nodeSignature(typeIdInit), - name, - name, - DispatchTypes.STATIC_DISPATCH - ) - - val typeAst = unknownNode(typeIdInit.getTypeId, nodeSignature(typeIdInit.getTypeId)) - val expr = astForNode(typeIdInit.getInitializer) - callAst(cpgCastExpression, List(Ast(typeAst), expr)) - - private def astForConstructorExpression(c: ICPPASTSimpleTypeConstructorExpression): Ast = - val name = c.getDeclSpecifier.toString - val callNode_ = callNode(c, nodeSignature(c), name, name, DispatchTypes.STATIC_DISPATCH) - val arg = astForNode(c.getInitializer) - callAst(callNode_, List(arg)) - - private def astForCompoundStatementExpression( - compoundExpression: IGNUASTCompoundStatementExpression - ): Ast = - nullSafeAst(compoundExpression.getCompoundStatement).headOption.getOrElse(Ast()) - - private def astForPackExpansionExpression( - packExpansionExpression: ICPPASTPackExpansionExpression - ): Ast = - astForExpression(packExpansionExpression.getPattern) - - protected def astForExpression(expression: IASTExpression): Ast = - if config.includeFunctionBodies then - astForExpressionFull(expression) - else astForExpressionPartial(expression) - - protected def astForExpressionFull(expression: IASTExpression): Ast = - val r = expression match - case lit: IASTLiteralExpression => astForLiteral(lit) - case un: IASTUnaryExpression => astForUnaryExpression(un) - case bin: IASTBinaryExpression => astForBinaryExpression(bin) - case exprList: IASTExpressionList => astForExpressionList(exprList) - case idExpr: IASTIdExpression => astForIdExpression(idExpr) - case call: IASTFunctionCallExpression => astForCallExpression(call) - case typeId: IASTTypeIdExpression => astForTypeIdExpression(typeId) - case fieldRef: IASTFieldReference => astForFieldReference(fieldRef) - case expr: IASTConditionalExpression => astForConditionalExpression(expr) - case arr: IASTArraySubscriptExpression => astForArrayIndexExpression(arr) - case castExpression: IASTCastExpression => astForCastExpression(castExpression) - case newExpression: ICPPASTNewExpression => astForNewExpression(newExpression) - case delExpression: ICPPASTDeleteExpression => astForDeleteExpression(delExpression) - case typeIdInit: IASTTypeIdInitializerExpression => - astForTypeIdInitExpression(typeIdInit) - case c: ICPPASTSimpleTypeConstructorExpression => astForConstructorExpression(c) - case lambdaExpression: ICPPASTLambdaExpression => - astForMethodRefForLambda(lambdaExpression) - case cExpr: IGNUASTCompoundStatementExpression => - astForCompoundStatementExpression(cExpr) - case pExpr: ICPPASTPackExpansionExpression => astForPackExpansionExpression(pExpr) - case _ => notHandledYet(expression) - asChildOfMacroCall(expression, r) - end astForExpressionFull - - protected def astForExpressionPartial(expression: IASTExpression): Ast = - val r = expression match - case call: IASTFunctionCallExpression => astForCallExpression(call) - case typeId: IASTTypeIdExpression => astForTypeIdExpression(typeId) - case fieldRef: IASTFieldReference => astForFieldReference(fieldRef) - case newExpression: ICPPASTNewExpression => astForNewExpression(newExpression) - case typeIdInit: IASTTypeIdInitializerExpression => - astForTypeIdInitExpression(typeIdInit) - case c: ICPPASTSimpleTypeConstructorExpression => astForConstructorExpression(c) - case _ => notHandledYet(expression) - asChildOfMacroCall(expression, r) - - private def astForIdExpression(idExpression: IASTIdExpression): Ast = idExpression.getName match - case name: CPPASTQualifiedName => astForQualifiedName(name) - case _ => astForIdentifier(idExpression) - - protected def astForStaticAssert(a: ICPPASTStaticAssertDeclaration): Ast = - val name = "static_assert" - val call = callNode(a, nodeSignature(a), name, name, DispatchTypes.STATIC_DISPATCH) - val cond = nullSafeAst(a.getCondition) - val messg = nullSafeAst(a.getMessage) - callAst(call, List(cond, messg)) + val arg = astForNode(typeId.getTypeId.getDeclSpecifier) + callAst(call, List(arg)) + case _ => notHandledYet(typeId) + + private def astForConditionalExpression(expr: IASTConditionalExpression): Ast = + val name = Operators.conditional + val call = callNode(expr, code(expr), name, name, DispatchTypes.STATIC_DISPATCH) + + val condAst = nullSafeAst(expr.getLogicalConditionExpression) + val posAst = nullSafeAst(expr.getPositiveResultExpression) + val negAst = nullSafeAst(expr.getNegativeResultExpression) + + val children = List(condAst, posAst, negAst) + callAst(call, children) + + private def astForArrayIndexExpression(arrayIndexExpression: IASTArraySubscriptExpression): Ast = + val name = Operators.indirectIndexAccess + val cpgArrayIndexing = + callNode( + arrayIndexExpression, + code(arrayIndexExpression), + name, + name, + DispatchTypes.STATIC_DISPATCH + ) + + val expr = astForExpression(arrayIndexExpression.getArrayExpression) + val arg = astForNode(arrayIndexExpression.getArgument) + callAst(cpgArrayIndexing, List(expr, arg)) + + private def astForCastExpression(castExpression: IASTCastExpression): Ast = + val cpgCastExpression = + callNode( + castExpression, + code(castExpression), + Operators.cast, + Operators.cast, + DispatchTypes.STATIC_DISPATCH + ) + + val expr = astForExpression(castExpression.getOperand) + val argNode = castExpression.getTypeId + val arg = unknownNode(argNode, code(argNode)) + + callAst(cpgCastExpression, List(Ast(arg), expr)) + + private def astsForConstructorInitializer(initializer: IASTInitializer): List[Ast] = + initializer match + case init: ICPPASTConstructorInitializer => init.getArguments.toList.map(x => astForNode(x)) + case _ => Nil // null or unexpected type + + private def astsForInitializerPlacements(initializerPlacements: Array[IASTInitializerClause]) + : List[Ast] = + if initializerPlacements != null then initializerPlacements.toList.map(x => astForNode(x)) + else Nil + + private def astForNewExpression(newExpression: ICPPASTNewExpression): Ast = + val name = ".new" + val cpgNewExpression = + callNode(newExpression, code(newExpression), name, name, DispatchTypes.STATIC_DISPATCH) + + val typeId = newExpression.getTypeId + if newExpression.isArrayAllocation then + val cpgTypeId = astForIdentifier(typeId.getDeclSpecifier) + Ast(cpgNewExpression).withChild(cpgTypeId).withArgEdge(cpgNewExpression, cpgTypeId.root.get) + else + val cpgTypeId = astForIdentifier(typeId.getDeclSpecifier) + val args = astsForConstructorInitializer(newExpression.getInitializer) ++ + astsForInitializerPlacements(newExpression.getPlacementArguments) + callAst(cpgNewExpression, List(cpgTypeId) ++ args) + + private def astForDeleteExpression(delExpression: ICPPASTDeleteExpression): Ast = + val name = Operators.delete + val cpgDeleteNode = + callNode(delExpression, code(delExpression), name, name, DispatchTypes.STATIC_DISPATCH) + val arg = astForExpression(delExpression.getOperand) + callAst(cpgDeleteNode, List(arg)) + + private def astForTypeIdInitExpression(typeIdInit: IASTTypeIdInitializerExpression): Ast = + val name = Operators.cast + val cpgCastExpression = + callNode(typeIdInit, code(typeIdInit), name, name, DispatchTypes.STATIC_DISPATCH) + + val typeAst = unknownNode(typeIdInit.getTypeId, code(typeIdInit.getTypeId)) + val expr = astForNode(typeIdInit.getInitializer) + callAst(cpgCastExpression, List(Ast(typeAst), expr)) + + private def astForConstructorExpression(c: ICPPASTSimpleTypeConstructorExpression): Ast = + val name = c.getDeclSpecifier.toString + val callNode_ = callNode(c, code(c), name, name, DispatchTypes.STATIC_DISPATCH) + val arg = astForNode(c.getInitializer) + callAst(callNode_, List(arg)) + + private def astForCompoundStatementExpression( + compoundExpression: IGNUASTCompoundStatementExpression + ): Ast = + nullSafeAst(compoundExpression.getCompoundStatement).headOption.getOrElse(Ast()) + + private def astForPackExpansionExpression(packExpansionExpression: ICPPASTPackExpansionExpression) + : Ast = + astForExpression(packExpansionExpression.getPattern) + + private def astForIdExpression(idExpression: IASTIdExpression): Ast = idExpression.getName match + case name: CPPASTQualifiedName => astForQualifiedName(name) + case _ => astForIdentifier(idExpression) end AstForExpressionsCreator diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForFunctionsCreator.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForFunctionsCreator.scala index 73ec30ce..aaf7ff49 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForFunctionsCreator.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForFunctionsCreator.scala @@ -1,10 +1,11 @@ package io.appthreat.c2cpg.astcreation +import io.appthreat.x2cpg.Defines as X2CpgDefines import io.appthreat.x2cpg.datastructures.Stack.* import io.appthreat.x2cpg.utils.NodeBuilders.newModifierNode import io.appthreat.x2cpg.{Ast, ValidationMode} -import io.shiftleft.codepropertygraph.generated.{EvaluationStrategies, ModifierTypes} import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.codepropertygraph.generated.{EvaluationStrategies, ModifierTypes} import org.apache.commons.lang.StringUtils import org.eclipse.cdt.core.dom.ast.* import org.eclipse.cdt.core.dom.ast.cpp.ICPPASTLambdaExpression @@ -21,260 +22,286 @@ import scala.annotation.tailrec import scala.collection.mutable trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode): - this: AstCreator => - - private val seenFunctionSignatures = mutable.HashSet.empty[String] - - protected def astForMethodRefForLambda(lambdaExpression: ICPPASTLambdaExpression): Ast = - val filename = fileName(lambdaExpression) - - val returnType = lambdaExpression.getDeclarator match - case declarator: IASTDeclarator => - declarator.getTrailingReturnType match - case id: IASTTypeId => typeForDeclSpecifier(id.getDeclSpecifier) - case null => Defines.anyTypeName - case null => Defines.anyTypeName - val (name, fullname) = uniqueName("lambda", "", fullName(lambdaExpression)) - val signature = - s"$returnType ${fullNameWithoutLocation(fullname)} ${parameterListSignature(lambdaExpression)}" - val code = nodeSignature(lambdaExpression) - val methodNode_ = - methodNode(lambdaExpression, name, code, fullname, Some(signature), filename) - - scope.pushNewScope(methodNode_) - val parameterNodes = withIndex(parameters(lambdaExpression.getDeclarator)) { (p, i) => - parameterNode(p, i) - } - setVariadic(parameterNodes, lambdaExpression) - - scope.popScope() - - val astForLambda = methodAst( - methodNode_, - parameterNodes.map(Ast(_)), - astForMethodBody(Option(lambdaExpression.getBody)), - newMethodReturnNode(lambdaExpression, registerType(returnType)) - ) - val typeDeclAst = - createFunctionTypeAndTypeDecl(lambdaExpression, methodNode_, name, fullname, signature) - Ast.storeInDiffGraph(astForLambda.merge(typeDeclAst), diffGraph) - - Ast(methodRefNode(lambdaExpression, code, fullname, methodNode_.astParentFullName)) - end astForMethodRefForLambda - - protected def astForFunctionDeclarator(funcDecl: IASTFunctionDeclarator): Ast = - val returnType = typeForDeclSpecifier( - funcDecl.getParent.asInstanceOf[IASTSimpleDeclaration].getDeclSpecifier - ) - val name = shortName(funcDecl) - val fullname = fullName(funcDecl) - val templateParams = templateParameters(funcDecl).getOrElse("") - val signature = - s"$returnType ${fullNameWithoutLocation(fullname)}$templateParams ${parameterListSignature(funcDecl)}" - - if seenFunctionSignatures.add(signature) then - val code = nodeSignature(funcDecl.getParent) - val filename = fileName(funcDecl) - val methodNode_ = methodNode(funcDecl, name, code, fullname, Some(signature), filename) - - scope.pushNewScope(methodNode_) - - val parameterNodes = withIndex(parameters(funcDecl)) { (p, i) => - parameterNode(p, i) - } - setVariadic(parameterNodes, funcDecl) - - scope.popScope() - - val stubAst = methodStubAst( - methodNode_, - parameterNodes, - newMethodReturnNode(funcDecl, registerType(returnType)) + this: AstCreator => + + private val seenFunctionFullnames = mutable.HashSet.empty[String] + + protected def astForMethodRefForLambda(lambdaExpression: ICPPASTLambdaExpression): Ast = + val filename = fileName(lambdaExpression) + + val returnType = lambdaExpression.getDeclarator match + case declarator: IASTDeclarator => + declarator.getTrailingReturnType match + case id: IASTTypeId => typeForDeclSpecifier(id.getDeclSpecifier) + case null => Defines.anyTypeName + case null => Defines.anyTypeName + val (name, fullname) = uniqueName("lambda", "", fullName(lambdaExpression)) + val signature = + s"$returnType ${parameterListSignature(lambdaExpression)}" + val code = nodeSignature(lambdaExpression) + val methodNode_ = + methodNode(lambdaExpression, name, code, fullname, Some(signature), filename) + + scope.pushNewScope(methodNode_) + val parameterNodes = withIndex(parameters(lambdaExpression.getDeclarator)) { (p, i) => + parameterNode(p, i) + } + setVariadic(parameterNodes, lambdaExpression) + + scope.popScope() + + val astForLambda = methodAst( + methodNode_, + parameterNodes.map(Ast(_)), + astForMethodBody(Option(lambdaExpression.getBody)), + newMethodReturnNode(lambdaExpression, registerType(returnType)) + ) + val typeDeclAst = + createFunctionTypeAndTypeDecl(lambdaExpression, methodNode_, name, fullname, signature) + Ast.storeInDiffGraph(astForLambda.merge(typeDeclAst), diffGraph) + + Ast(methodRefNode(lambdaExpression, code, fullname, methodNode_.astParentFullName)) + end astForMethodRefForLambda + + protected def astForFunctionDeclarator(funcDecl: IASTFunctionDeclarator): Ast = + funcDecl.getName.resolveBinding() match + case function: IFunction => + val returnType = typeForDeclSpecifier( + funcDecl.getParent.asInstanceOf[IASTSimpleDeclaration].getDeclSpecifier ) - val typeDeclAst = - createFunctionTypeAndTypeDecl(funcDecl, methodNode_, name, fullname, signature) - stubAst.merge(typeDeclAst) - else + val name = shortName(funcDecl) + val fullname = fullName(funcDecl) + val fixedName = if name.isEmpty then + nextClosureName() + else name + val fixedFullName = if fullname.isEmpty then + s"${X2CpgDefines.UnresolvedNamespace}.$name" + else fullname + val templateParams = templateParameters(funcDecl).getOrElse("") + val signature = + s"$returnType${parameterListSignature(funcDecl)}" + + if seenFunctionFullnames.add(fullname) then + val name = shortName(funcDecl) + val codeString = code(funcDecl.getParent) + val filename = fileName(funcDecl) + val methodNode_ = methodNode( + funcDecl, + fixedName, + codeString, + fixedFullName, + Some(signature), + filename + ) + + scope.pushNewScope(methodNode_) + + val parameterNodes = withIndex(parameters(funcDecl)) { (p, i) => + parameterNode(p, i) + } + setVariadic(parameterNodes, funcDecl) + + scope.popScope() + + val stubAst = + methodStubAst( + methodNode_, + parameterNodes, + newMethodReturnNode(funcDecl, registerType(returnType)) + ) + val typeDeclAst = createFunctionTypeAndTypeDecl( + funcDecl, + methodNode_, + fixedName, + fixedFullName, + signature + ) + stubAst.merge(typeDeclAst) + else + Ast() + end if + case field: IField => + Ast() + case typeDef: ITypedef => Ast() - end if - end astForFunctionDeclarator - - protected def astForFunctionDefinition(funcDef: IASTFunctionDefinition): Ast = - val filename = fileName(funcDef) - val returnType = if isCppConstructor(funcDef) then - typeFor(funcDef.asInstanceOf[ - CPPASTFunctionDefinition - ].getMemberInitializers.head.getInitializer) - else typeForDeclSpecifier(funcDef.getDeclSpecifier) - val name = shortName(funcDef) - val fullname = fullName(funcDef) - val templateParams = templateParameters(funcDef).getOrElse("") - - val signature = - s"$returnType ${fullNameWithoutLocation(fullname)}$templateParams ${parameterListSignature(funcDef)}" - seenFunctionSignatures.add(signature) - - val code = nodeSignature(funcDef) - val methodNode_ = methodNode(funcDef, name, code, fullname, Some(signature), filename) - - methodAstParentStack.push(methodNode_) - scope.pushNewScope(methodNode_) - - val parameterNodes = withIndex(parameters(funcDef)) { (p, i) => - parameterNode(p, i) - } - setVariadic(parameterNodes, funcDef) - val modifiers = if isCppConstructor(funcDef) then - List(newModifierNode(ModifierTypes.CONSTRUCTOR), newModifierNode(ModifierTypes.PUBLIC)) - else Nil - val astForMethod = methodAst( - methodNode_, - parameterNodes.map(Ast(_)), - astForMethodBody(Option(funcDef.getBody)), - newMethodReturnNode(funcDef, registerType(returnType)), - modifiers = modifiers + end astForFunctionDeclarator + + protected def astForFunctionDefinition(funcDef: IASTFunctionDefinition): Ast = + val filename = fileName(funcDef) + val returnType = if isCppConstructor(funcDef) then + typeFor(funcDef.asInstanceOf[ + CPPASTFunctionDefinition + ].getMemberInitializers.head.getInitializer) + else typeForDeclSpecifier(funcDef.getDeclSpecifier) + val name = shortName(funcDef) + val fullname = fullName(funcDef) + val templateParams = templateParameters(funcDef).getOrElse("") + + val signature = + s"$returnType $templateParams ${parameterListSignature(funcDef)}" + seenFunctionFullnames.add(fullname) + + val code = nodeSignature(funcDef) + val methodNode_ = methodNode(funcDef, name, code, fullname, Some(signature), filename) + + methodAstParentStack.push(methodNode_) + scope.pushNewScope(methodNode_) + + val parameterNodes = withIndex(parameters(funcDef)) { (p, i) => + parameterNode(p, i) + } + setVariadic(parameterNodes, funcDef) + val modifiers = if isCppConstructor(funcDef) then + List(newModifierNode(ModifierTypes.CONSTRUCTOR), newModifierNode(ModifierTypes.PUBLIC)) + else Nil + val astForMethod = methodAst( + methodNode_, + parameterNodes.map(Ast(_)), + astForMethodBody(Option(funcDef.getBody)), + newMethodReturnNode(funcDef, registerType(returnType)), + modifiers = modifiers + ) + + scope.popScope() + methodAstParentStack.pop() + + val typeDeclAst = + createFunctionTypeAndTypeDecl(funcDef, methodNode_, name, fullname, signature) + astForMethod.merge(typeDeclAst) + end astForFunctionDefinition + + private def createFunctionTypeAndTypeDecl( + node: IASTNode, + method: NewMethod, + methodName: String, + methodFullName: String, + signature: String + ): Ast = + val normalizedName = StringUtils.normalizeSpace(methodName) + val normalizedFullName = StringUtils.normalizeSpace(methodFullName) + + val parentNode: NewTypeDecl = methodAstParentStack.collectFirst { case t: NewTypeDecl => + t + }.getOrElse { + val astParentType = methodAstParentStack.head.label + val astParentFullName = methodAstParentStack.head.properties("FULL_NAME").toString + val typeDeclNode_ = typeDeclNode( + node, + normalizedName, + normalizedFullName, + method.filename, + normalizedName, + astParentType, + astParentFullName ) - - scope.popScope() - methodAstParentStack.pop() - - val typeDeclAst = - createFunctionTypeAndTypeDecl(funcDef, methodNode_, name, fullname, signature) - astForMethod.merge(typeDeclAst) - end astForFunctionDefinition - - private def createFunctionTypeAndTypeDecl( - node: IASTNode, - method: NewMethod, - methodName: String, - methodFullName: String, - signature: String - ): Ast = - val normalizedName = StringUtils.normalizeSpace(methodName) - val normalizedFullName = StringUtils.normalizeSpace(methodFullName) - - val parentNode: NewTypeDecl = methodAstParentStack.collectFirst { case t: NewTypeDecl => - t - }.getOrElse { - val astParentType = methodAstParentStack.head.label - val astParentFullName = methodAstParentStack.head.properties("FULL_NAME").toString - val typeDeclNode_ = typeDeclNode( - node, - normalizedName, - normalizedFullName, - method.filename, - normalizedName, - astParentType, - astParentFullName - ) - Ast.storeInDiffGraph(Ast(typeDeclNode_), diffGraph) - typeDeclNode_ - } - - method.astParentFullName = parentNode.fullName - method.astParentType = parentNode.label - val functionBinding = NewBinding().name(normalizedName).methodFullName( - normalizedFullName - ).signature(signature) - Ast(functionBinding).withBindsEdge(parentNode, functionBinding).withRefEdge( - functionBinding, - method + Ast.storeInDiffGraph(Ast(typeDeclNode_), diffGraph) + typeDeclNode_ + } + + method.astParentFullName = parentNode.fullName + method.astParentType = parentNode.label + val functionBinding = NewBinding().name(normalizedName).methodFullName( + normalizedFullName + ).signature(signature) + Ast(functionBinding).withBindsEdge(parentNode, functionBinding).withRefEdge( + functionBinding, + method + ) + end createFunctionTypeAndTypeDecl + + private def parameters(functionNode: IASTNode): Seq[IASTNode] = functionNode match + case arr: IASTArrayDeclarator => parameters(arr.getNestedDeclarator) + case decl: CPPASTFunctionDeclarator => + decl.getParameters.toIndexedSeq ++ parameters(decl.getNestedDeclarator) + case decl: CASTFunctionDeclarator => + decl.getParameters.toIndexedSeq ++ parameters(decl.getNestedDeclarator) + case defn: IASTFunctionDefinition => parameters(defn.getDeclarator) + case lambdaExpression: ICPPASTLambdaExpression => parameters(lambdaExpression.getDeclarator) + case knr: ICASTKnRFunctionDeclarator => knr.getParameterDeclarations.toIndexedSeq + case _: IASTDeclarator => Seq.empty + case other if other != null => notHandledYet(other); Seq.empty + case null => Seq.empty + + @tailrec + private def isVariadic(functionNode: IASTNode): Boolean = functionNode match + case decl: CPPASTFunctionDeclarator => decl.takesVarArgs() + case decl: CASTFunctionDeclarator => decl.takesVarArgs() + case defn: IASTFunctionDefinition => isVariadic(defn.getDeclarator) + case lambdaExpression: ICPPASTLambdaExpression => isVariadic(lambdaExpression.getDeclarator) + case _ => false + + private def parameterListSignature(func: IASTNode): String = + val variadic = if isVariadic(func) then "..." else "" + val elements = parameters(func).map { + case p: IASTParameterDeclaration => typeForDeclSpecifier(p.getDeclSpecifier) + case other => typeForDeclSpecifier(other) + } + s"(${elements.mkString(",")}$variadic)" + + private def setVariadic(parameterNodes: Seq[NewMethodParameterIn], func: IASTNode): Unit = + parameterNodes.lastOption.foreach { + case p: NewMethodParameterIn if isVariadic(func) => + p.isVariadic = true + p.code = s"${p.code}..." + case _ => + } + + private def fullNameWithoutLocation(fullName: String) = fullName.split(":").last + + private def isCppConstructor(funcDef: IASTFunctionDefinition): Boolean = + funcDef match + case cppFunc: CPPASTFunctionDefinition => cppFunc.getMemberInitializers.nonEmpty + case _ => false + + private def parameterNode(parameter: IASTNode, paramIndex: Int): NewMethodParameterIn = + val (name, code, tpe, variadic) = parameter match + case p: CASTParameterDeclaration => + ( + ASTStringUtil.getSimpleName(p.getDeclarator.getName), + nodeSignature(p), + cleanType(typeForDeclSpecifier(p.getDeclSpecifier)), + false + ) + case p: CPPASTParameterDeclaration => + ( + ASTStringUtil.getSimpleName(p.getDeclarator.getName), + nodeSignature(p), + cleanType(typeForDeclSpecifier(p.getDeclSpecifier)), + p.getDeclarator.declaresParameterPack() + ) + case s: IASTSimpleDeclaration => + ( + s.getDeclarators.headOption + .map(n => ASTStringUtil.getSimpleName(n.getName)) + .getOrElse(uniqueName("parameter", "", "")._1), + nodeSignature(s), + cleanType(typeForDeclSpecifier(s)), + false + ) + case other => + ( + nodeSignature(other), + nodeSignature(other), + cleanType(typeForDeclSpecifier(other)), + false + ) + + val parameterNode = + parameterInNode( + parameter, + name, + code, + paramIndex, + variadic, + EvaluationStrategies.BY_VALUE, + registerType(tpe) ) - end createFunctionTypeAndTypeDecl - - private def parameters(functionNode: IASTNode): Seq[IASTNode] = functionNode match - case arr: IASTArrayDeclarator => parameters(arr.getNestedDeclarator) - case decl: CPPASTFunctionDeclarator => - decl.getParameters.toIndexedSeq ++ parameters(decl.getNestedDeclarator) - case decl: CASTFunctionDeclarator => - decl.getParameters.toIndexedSeq ++ parameters(decl.getNestedDeclarator) - case defn: IASTFunctionDefinition => parameters(defn.getDeclarator) - case lambdaExpression: ICPPASTLambdaExpression => parameters(lambdaExpression.getDeclarator) - case knr: ICASTKnRFunctionDeclarator => knr.getParameterDeclarations.toIndexedSeq - case _: IASTDeclarator => Seq.empty - case other if other != null => notHandledYet(other); Seq.empty - case null => Seq.empty - - @tailrec - private def isVariadic(functionNode: IASTNode): Boolean = functionNode match - case decl: CPPASTFunctionDeclarator => decl.takesVarArgs() - case decl: CASTFunctionDeclarator => decl.takesVarArgs() - case defn: IASTFunctionDefinition => isVariadic(defn.getDeclarator) - case lambdaExpression: ICPPASTLambdaExpression => isVariadic(lambdaExpression.getDeclarator) - case _ => false - - private def parameterListSignature(func: IASTNode): String = - val variadic = if isVariadic(func) then "..." else "" - val elements = parameters(func).map { - case p: IASTParameterDeclaration => typeForDeclSpecifier(p.getDeclSpecifier) - case other => typeForDeclSpecifier(other) - } - s"(${elements.mkString(",")}$variadic)" - - private def setVariadic(parameterNodes: Seq[NewMethodParameterIn], func: IASTNode): Unit = - parameterNodes.lastOption.foreach { - case p: NewMethodParameterIn if isVariadic(func) => - p.isVariadic = true - p.code = s"${p.code}..." - case _ => - } - - private def fullNameWithoutLocation(fullName: String) = fullName.split(":").last - - private def isCppConstructor(funcDef: IASTFunctionDefinition): Boolean = - funcDef match - case cppFunc: CPPASTFunctionDefinition => cppFunc.getMemberInitializers.nonEmpty - case _ => false - - private def parameterNode(parameter: IASTNode, paramIndex: Int): NewMethodParameterIn = - val (name, code, tpe, variadic) = parameter match - case p: CASTParameterDeclaration => - ( - ASTStringUtil.getSimpleName(p.getDeclarator.getName), - nodeSignature(p), - cleanType(typeForDeclSpecifier(p.getDeclSpecifier)), - false - ) - case p: CPPASTParameterDeclaration => - ( - ASTStringUtil.getSimpleName(p.getDeclarator.getName), - nodeSignature(p), - cleanType(typeForDeclSpecifier(p.getDeclSpecifier)), - p.getDeclarator.declaresParameterPack() - ) - case s: IASTSimpleDeclaration => - ( - s.getDeclarators.headOption - .map(n => ASTStringUtil.getSimpleName(n.getName)) - .getOrElse(uniqueName("parameter", "", "")._1), - nodeSignature(s), - cleanType(typeForDeclSpecifier(s)), - false - ) - case other => - ( - nodeSignature(other), - nodeSignature(other), - cleanType(typeForDeclSpecifier(other)), - false - ) - - val parameterNode = - parameterInNode( - parameter, - name, - code, - paramIndex, - variadic, - EvaluationStrategies.BY_VALUE, - registerType(tpe) - ) - scope.addToScope(name, (parameterNode, tpe)) - parameterNode - end parameterNode - - private def astForMethodBody(body: Option[IASTStatement]): Ast = body match - case Some(b: IASTCompoundStatement) => astForBlockStatement(b) - case Some(b) => astForNode(b) - case None => blockAst(NewBlock()) + scope.addToScope(name, (parameterNode, tpe)) + parameterNode + end parameterNode + + private def astForMethodBody(body: Option[IASTStatement]): Ast = body match + case Some(b: IASTCompoundStatement) => astForBlockStatement(b) + case Some(b) => astForNode(b) + case None => blockAst(NewBlock()) end AstForFunctionsCreator diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForPrimitivesCreator.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForPrimitivesCreator.scala index 1837c5db..5595d529 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForPrimitivesCreator.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForPrimitivesCreator.scala @@ -10,156 +10,156 @@ import org.eclipse.cdt.internal.core.dom.parser.cpp.ICPPInternalBinding import org.eclipse.cdt.internal.core.model.ASTStringUtil trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode): - this: AstCreator => - - protected def astForComment(comment: IASTComment): Ast = - Ast(newCommentNode(comment, nodeSignature(comment), fileName(comment))) - - protected def astForLiteral(lit: IASTLiteralExpression): Ast = - val tpe = cleanType(ASTTypeUtil.getType(lit.getExpressionType)) - Ast(literalNode(lit, nodeSignature(lit), registerType(tpe))) - - private def namesForBinding(binding: ICInternalBinding | ICPPInternalBinding) - : (Option[String], Option[String]) = - val definition = binding match - // sadly, there is no common interface defining .getDefinition - case b: ICInternalBinding => b.getDefinition.asInstanceOf[IASTFunctionDeclarator] - case b: ICPPInternalBinding => b.getDefinition.asInstanceOf[IASTFunctionDeclarator] - val typeFullName = definition.getParent match - case d: IASTFunctionDefinition => Some(typeForDeclSpecifier(d.getDeclSpecifier)) - case _ => None - (Some(this.fullName(definition)), typeFullName) - - private def maybeMethodRefForIdentifier(ident: IASTNode): Option[NewMethodRef] = - ident match - case id: IASTIdExpression if id.getName != null => - id.getName.resolveBinding() - val (mayBeFullName, mayBeTypeFullName) = id.getName.getBinding match - case binding: ICInternalBinding - if binding.getDefinition.isInstanceOf[IASTFunctionDeclarator] => - namesForBinding(binding) - case binding: ICPPInternalBinding - if binding.getDefinition.isInstanceOf[IASTFunctionDeclarator] => - namesForBinding(binding) - case _ => (None, None) - for - fullName <- mayBeFullName - typeFullName <- mayBeTypeFullName - yield methodRefNode(ident, code(ident), fullName, typeFullName) - case _ => None - - protected def astForIdentifier(ident: IASTNode): Ast = - maybeMethodRefForIdentifier(ident) match - case Some(ref) => Ast(ref) - case None => - val identifierName = ident match - case id: IASTIdExpression => ASTStringUtil.getSimpleName(id.getName) - case id: IASTName - if ASTStringUtil.getSimpleName(id).isEmpty && id.getBinding != null => - id.getBinding.getName - case id: IASTName if ASTStringUtil.getSimpleName(id).isEmpty => - uniqueName("name", "", "")._1 - case _ => code(ident) - val variableOption = scope.lookupVariable(identifierName) - val identifierTypeName = variableOption match - case Some((_, variableTypeName)) => variableTypeName - case None - if ident.isInstanceOf[IASTName] && ident.asInstanceOf[ - IASTName - ].getBinding != null => - val id = ident.asInstanceOf[IASTName] - id.getBinding match - case v: IVariable => - v.getType match - case f: IFunctionType => f.getReturnType.toString - case other => other.toString - case other => other.getName - case None if ident.isInstanceOf[IASTName] => - typeFor(ident.getParent) - case None => typeFor(ident) - - val node = identifierNode( - ident, - identifierName, - code(ident), - registerType(cleanType(identifierTypeName)) - ) - variableOption match - case Some((variable, _)) => - Ast(node).withRefEdge(node, variable) - case None => Ast(node) - - protected def astForFieldReference(fieldRef: IASTFieldReference): Ast = - val op = if fieldRef.isPointerDereference then Operators.indirectFieldAccess - else Operators.fieldAccess - val ma = callNode( - fieldRef, - nodeSignature(fieldRef), - op, - op, - if fieldRef.isPointerDereference then DispatchTypes.DYNAMIC_DISPATCH - else DispatchTypes.STATIC_DISPATCH - ) - val owner = astForExpression(fieldRef.getFieldOwner) - val member = fieldIdentifierNode( - fieldRef, - fieldRef.getFieldName.toString, - fieldRef.getFieldName.toString - ) - callAst(ma, List(owner, Ast(member))) - - protected def astForArrayModifier(arrMod: IASTArrayModifier): Ast = - astForNode(arrMod.getConstantExpression) - - protected def astForInitializerList(l: IASTInitializerList): Ast = - val op = Operators.arrayInitializer - val initCallNode = callNode(l, nodeSignature(l), op, op, DispatchTypes.STATIC_DISPATCH) - - val MAX_INITIALIZERS = 1000 - val clauses = l.getClauses.slice(0, MAX_INITIALIZERS) - - val args = clauses.toList.map(x => astForNode(x)) - - val ast = callAst(initCallNode, args) - if l.getClauses.length > MAX_INITIALIZERS then - val placeholder = - literalNode(l, "", Defines.anyTypeName).argumentIndex( - MAX_INITIALIZERS - ) - ast.withChild(Ast(placeholder)).withArgEdge(initCallNode, placeholder) - else - ast - - protected def astForQualifiedName(qualId: CPPASTQualifiedName): Ast = - val op = Operators.fieldAccess - val ma = callNode(qualId, nodeSignature(qualId), op, op, DispatchTypes.STATIC_DISPATCH) - - def fieldAccesses(names: List[IASTNode], argIndex: Int = -1): Ast = names match - case Nil => Ast() - case head :: Nil => - astForNode(head) - case head :: tail => - val code = s"${nodeSignature(head)}::${tail.map(nodeSignature).mkString("::")}" - val callNode_ = - callNode(head, nodeSignature(head), op, op, DispatchTypes.STATIC_DISPATCH) - .argumentIndex(argIndex) - callNode_.code = code - val arg1 = astForNode(head) - val arg2 = fieldAccesses(tail) - callAst(callNode_, List(arg1, arg2)) - - val qualifier = fieldAccesses(qualId.getQualifier.toIndexedSeq.toList) - - val owner = if qualifier != Ast() then - qualifier - else - Ast(literalNode(qualId.getLastName, "", Defines.anyTypeName)) - - val member = fieldIdentifierNode( - qualId.getLastName, - fixQualifiedName(qualId.getLastName.toString), - qualId.getLastName.toString - ) - callAst(ma, List(owner, Ast(member))) - end astForQualifiedName + this: AstCreator => + + protected def astForComment(comment: IASTComment): Ast = + Ast(newCommentNode(comment, nodeSignature(comment), fileName(comment))) + + protected def astForLiteral(lit: IASTLiteralExpression): Ast = + val tpe = cleanType(ASTTypeUtil.getType(lit.getExpressionType)) + Ast(literalNode(lit, nodeSignature(lit), registerType(tpe))) + + private def namesForBinding(binding: ICInternalBinding | ICPPInternalBinding) + : (Option[String], Option[String]) = + val definition = binding match + // sadly, there is no common interface defining .getDefinition + case b: ICInternalBinding => b.getDefinition.asInstanceOf[IASTFunctionDeclarator] + case b: ICPPInternalBinding => b.getDefinition.asInstanceOf[IASTFunctionDeclarator] + val typeFullName = definition.getParent match + case d: IASTFunctionDefinition => Some(typeForDeclSpecifier(d.getDeclSpecifier)) + case _ => None + (Some(this.fullName(definition)), typeFullName) + + private def maybeMethodRefForIdentifier(ident: IASTNode): Option[NewMethodRef] = + ident match + case id: IASTIdExpression if id.getName != null => + id.getName.resolveBinding() + val (mayBeFullName, mayBeTypeFullName) = id.getName.getBinding match + case binding: ICInternalBinding + if binding.getDefinition.isInstanceOf[IASTFunctionDeclarator] => + namesForBinding(binding) + case binding: ICPPInternalBinding + if binding.getDefinition.isInstanceOf[IASTFunctionDeclarator] => + namesForBinding(binding) + case _ => (None, None) + for + fullName <- mayBeFullName + typeFullName <- mayBeTypeFullName + yield methodRefNode(ident, code(ident), fullName, registerType(cleanType(typeFullName))) + case _ => None + + protected def astForIdentifier(ident: IASTNode): Ast = + maybeMethodRefForIdentifier(ident) match + case Some(ref) => Ast(ref) + case None => + val identifierName = ident match + case id: IASTIdExpression => ASTStringUtil.getSimpleName(id.getName) + case id: IASTName + if ASTStringUtil.getSimpleName(id).isEmpty && id.getBinding != null => + id.getBinding.getName + case id: IASTName if ASTStringUtil.getSimpleName(id).isEmpty => + uniqueName("name", "", "")._1 + case _ => code(ident) + val variableOption = scope.lookupVariable(identifierName) + val identifierTypeName = variableOption match + case Some((_, variableTypeName)) => variableTypeName + case None + if ident.isInstanceOf[IASTName] && ident.asInstanceOf[ + IASTName + ].getBinding != null => + val id = ident.asInstanceOf[IASTName] + id.getBinding match + case v: IVariable => + v.getType match + case f: IFunctionType => f.getReturnType.toString + case other => other.toString + case other => other.getName + case None if ident.isInstanceOf[IASTName] => + typeFor(ident.getParent) + case None => typeFor(ident) + + val node = identifierNode( + ident, + identifierName, + code(ident), + registerType(cleanType(identifierTypeName)) + ) + variableOption match + case Some((variable, _)) => + Ast(node).withRefEdge(node, variable) + case None => Ast(node) + + protected def astForFieldReference(fieldRef: IASTFieldReference): Ast = + val op = if fieldRef.isPointerDereference then Operators.indirectFieldAccess + else Operators.fieldAccess + val ma = callNode( + fieldRef, + nodeSignature(fieldRef), + op, + op, + if fieldRef.isPointerDereference then DispatchTypes.DYNAMIC_DISPATCH + else DispatchTypes.STATIC_DISPATCH + ) + val owner = astForExpression(fieldRef.getFieldOwner) + val member = fieldIdentifierNode( + fieldRef, + fieldRef.getFieldName.toString, + fieldRef.getFieldName.toString + ) + callAst(ma, List(owner, Ast(member))) + + protected def astForArrayModifier(arrMod: IASTArrayModifier): Ast = + astForNode(arrMod.getConstantExpression) + + protected def astForInitializerList(l: IASTInitializerList): Ast = + val op = Operators.arrayInitializer + val initCallNode = callNode(l, nodeSignature(l), op, op, DispatchTypes.STATIC_DISPATCH) + + val MAX_INITIALIZERS = 1000 + val clauses = l.getClauses.slice(0, MAX_INITIALIZERS) + + val args = clauses.toList.map(x => astForNode(x)) + + val ast = callAst(initCallNode, args) + if l.getClauses.length > MAX_INITIALIZERS then + val placeholder = + literalNode(l, "", Defines.anyTypeName).argumentIndex( + MAX_INITIALIZERS + ) + ast.withChild(Ast(placeholder)).withArgEdge(initCallNode, placeholder) + else + ast + + protected def astForQualifiedName(qualId: CPPASTQualifiedName): Ast = + val op = Operators.fieldAccess + val ma = callNode(qualId, nodeSignature(qualId), op, op, DispatchTypes.STATIC_DISPATCH) + + def fieldAccesses(names: List[IASTNode], argIndex: Int = -1): Ast = names match + case Nil => Ast() + case head :: Nil => + astForNode(head) + case head :: tail => + val code = s"${nodeSignature(head)}::${tail.map(nodeSignature).mkString("::")}" + val callNode_ = + callNode(head, nodeSignature(head), op, op, DispatchTypes.STATIC_DISPATCH) + .argumentIndex(argIndex) + callNode_.code = code + val arg1 = astForNode(head) + val arg2 = fieldAccesses(tail) + callAst(callNode_, List(arg1, arg2)) + + val qualifier = fieldAccesses(qualId.getQualifier.toIndexedSeq.toList) + + val owner = if qualifier != Ast() then + qualifier + else + Ast(literalNode(qualId.getLastName, "", Defines.anyTypeName)) + + val member = fieldIdentifierNode( + qualId.getLastName, + fixQualifiedName(qualId.getLastName.toString), + qualId.getLastName.toString + ) + callAst(ma, List(owner, Ast(member))) + end astForQualifiedName end AstForPrimitivesCreator diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForStatementsCreator.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForStatementsCreator.scala index aa04f201..5847ded0 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForStatementsCreator.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForStatementsCreator.scala @@ -12,252 +12,252 @@ import org.eclipse.cdt.internal.core.dom.parser.cpp.CPPASTNamespaceAlias import org.eclipse.cdt.internal.core.model.ASTStringUtil trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode): - this: AstCreator => + this: AstCreator => - import AstCreatorHelper.OptionSafeAst + import AstCreatorHelper.OptionSafeAst - protected def astForBlockStatement(blockStmt: IASTCompoundStatement, order: Int = -1): Ast = - val code = nodeSignature(blockStmt) - val blockCode = if code == "{}" || code.isEmpty then Defines.empty else code - val node = blockNode(blockStmt, blockCode, registerType(Defines.voidTypeName)) - .order(order) - .argumentIndex(order) - scope.pushNewScope(node) - var currOrder = 1 - val childAsts = blockStmt.getStatements.flatMap { stmt => - val r = astsForStatement(stmt, currOrder) - currOrder = currOrder + r.length - r - } - scope.popScope() - blockAst(node, childAsts.toList) + protected def astForBlockStatement(blockStmt: IASTCompoundStatement, order: Int = -1): Ast = + val code = nodeSignature(blockStmt) + val blockCode = if code == "{}" || code.isEmpty then Defines.empty else code + val node = blockNode(blockStmt, blockCode, registerType(Defines.voidTypeName)) + .order(order) + .argumentIndex(order) + scope.pushNewScope(node) + var currOrder = 1 + val childAsts = blockStmt.getStatements.flatMap { stmt => + val r = astsForStatement(stmt, currOrder) + currOrder = currOrder + r.length + r + } + scope.popScope() + blockAst(node, childAsts.toList) - private def astsForDeclarationStatement(decl: IASTDeclarationStatement): Seq[Ast] = - decl.getDeclaration match - case simplDecl: IASTSimpleDeclaration - if simplDecl.getDeclarators.headOption.exists( - _.isInstanceOf[IASTFunctionDeclarator] - ) => - Seq(astForFunctionDeclarator( - simplDecl.getDeclarators.head.asInstanceOf[IASTFunctionDeclarator] - )) - case simplDecl: IASTSimpleDeclaration => - val locals = - simplDecl.getDeclarators.zipWithIndex.toList.map { case (d, i) => - astForDeclarator(simplDecl, d, i) - } - val calls = - simplDecl.getDeclarators.filter(_.getInitializer != null).toList.map { d => - astForInitializer(d, d.getInitializer) - } - locals ++ calls - case s: ICPPASTStaticAssertDeclaration => Seq(astForStaticAssert(s)) - case usingDeclaration: ICPPASTUsingDeclaration => - handleUsingDeclaration(usingDeclaration) - case alias: ICPPASTAliasDeclaration => Seq(astForAliasDeclaration(alias)) - case func: IASTFunctionDefinition => Seq(astForFunctionDefinition(func)) - case alias: CPPASTNamespaceAlias => Seq(astForNamespaceAlias(alias)) - case asm: IASTASMDeclaration => Seq(astForASMDeclaration(asm)) - case _: ICPPASTUsingDirective => Seq.empty - case decl => Seq(astForNode(decl)) + private def astsForDeclarationStatement(decl: IASTDeclarationStatement): Seq[Ast] = + decl.getDeclaration match + case simplDecl: IASTSimpleDeclaration + if simplDecl.getDeclarators.headOption.exists( + _.isInstanceOf[IASTFunctionDeclarator] + ) => + Seq(astForFunctionDeclarator( + simplDecl.getDeclarators.head.asInstanceOf[IASTFunctionDeclarator] + )) + case simplDecl: IASTSimpleDeclaration => + val locals = + simplDecl.getDeclarators.zipWithIndex.toList.map { case (d, i) => + astForDeclarator(simplDecl, d, i) + } + val calls = + simplDecl.getDeclarators.filter(_.getInitializer != null).toList.map { d => + astForInitializer(d, d.getInitializer) + } + locals ++ calls + case s: ICPPASTStaticAssertDeclaration => Seq(astForStaticAssert(s)) + case usingDeclaration: ICPPASTUsingDeclaration => + handleUsingDeclaration(usingDeclaration) + case alias: ICPPASTAliasDeclaration => Seq(astForAliasDeclaration(alias)) + case func: IASTFunctionDefinition => Seq(astForFunctionDefinition(func)) + case alias: CPPASTNamespaceAlias => Seq(astForNamespaceAlias(alias)) + case asm: IASTASMDeclaration => Seq(astForASMDeclaration(asm)) + case _: ICPPASTUsingDirective => Seq.empty + case decl => Seq(astForNode(decl)) - private def astForReturnStatement(ret: IASTReturnStatement): Ast = - val cpgReturn = returnNode(ret, nodeSignature(ret)) - val expr = nullSafeAst(ret.getReturnValue) - Ast(cpgReturn).withChild(expr).withArgEdge(cpgReturn, expr.root) + private def astForReturnStatement(ret: IASTReturnStatement): Ast = + val cpgReturn = returnNode(ret, nodeSignature(ret)) + val expr = nullSafeAst(ret.getReturnValue) + Ast(cpgReturn).withChild(expr).withArgEdge(cpgReturn, expr.root) - private def astForBreakStatement(br: IASTBreakStatement): Ast = - Ast(controlStructureNode(br, ControlStructureTypes.BREAK, nodeSignature(br))) + private def astForBreakStatement(br: IASTBreakStatement): Ast = + Ast(controlStructureNode(br, ControlStructureTypes.BREAK, nodeSignature(br))) - private def astForContinueStatement(cont: IASTContinueStatement): Ast = - Ast(controlStructureNode(cont, ControlStructureTypes.CONTINUE, nodeSignature(cont))) + private def astForContinueStatement(cont: IASTContinueStatement): Ast = + Ast(controlStructureNode(cont, ControlStructureTypes.CONTINUE, nodeSignature(cont))) - private def astForGotoStatement(goto: IASTGotoStatement): Ast = - val code = s"goto ${ASTStringUtil.getSimpleName(goto.getName)};" - Ast(controlStructureNode(goto, ControlStructureTypes.GOTO, code)) + private def astForGotoStatement(goto: IASTGotoStatement): Ast = + val code = s"goto ${ASTStringUtil.getSimpleName(goto.getName)};" + Ast(controlStructureNode(goto, ControlStructureTypes.GOTO, code)) - private def astsForGnuGotoStatement(goto: IGNUASTGotoStatement): Seq[Ast] = - // This is for GNU GOTO labels as values. - // See: https://gcc.gnu.org/onlinedocs/gcc/Labels-as-Values.html - // For such GOTOs we cannot statically determine the target label. As a quick - // hack we simply put edges to all labels found indicated by *. This might be an over-taint. - val code = s"goto *;" - val gotoNode = Ast(controlStructureNode(goto, ControlStructureTypes.GOTO, code)) - val exprNode = nullSafeAst(goto.getLabelNameExpression) - Seq(gotoNode, exprNode) + private def astsForGnuGotoStatement(goto: IGNUASTGotoStatement): Seq[Ast] = + // This is for GNU GOTO labels as values. + // See: https://gcc.gnu.org/onlinedocs/gcc/Labels-as-Values.html + // For such GOTOs we cannot statically determine the target label. As a quick + // hack we simply put edges to all labels found indicated by *. This might be an over-taint. + val code = s"goto *;" + val gotoNode = Ast(controlStructureNode(goto, ControlStructureTypes.GOTO, code)) + val exprNode = nullSafeAst(goto.getLabelNameExpression) + Seq(gotoNode, exprNode) - private def astsForLabelStatement(label: IASTLabelStatement): Seq[Ast] = - val cpgLabel = newJumpTargetNode(label) - val nestedStmts = nullSafeAst(label.getNestedStatement) - Ast(cpgLabel) +: nestedStmts + private def astsForLabelStatement(label: IASTLabelStatement): Seq[Ast] = + val cpgLabel = newJumpTargetNode(label) + val nestedStmts = nullSafeAst(label.getNestedStatement) + Ast(cpgLabel) +: nestedStmts - private def astForDoStatement(doStmt: IASTDoStatement): Ast = - val code = nodeSignature(doStmt) - val doNode = controlStructureNode(doStmt, ControlStructureTypes.DO, code) - val conditionAst = astForConditionExpression(doStmt.getCondition) - val bodyAst = nullSafeAst(doStmt.getBody) - controlStructureAst(doNode, Some(conditionAst), bodyAst, placeConditionLast = true) + private def astForDoStatement(doStmt: IASTDoStatement): Ast = + val code = nodeSignature(doStmt) + val doNode = controlStructureNode(doStmt, ControlStructureTypes.DO, code) + val conditionAst = astForConditionExpression(doStmt.getCondition) + val bodyAst = nullSafeAst(doStmt.getBody) + controlStructureAst(doNode, Some(conditionAst), bodyAst, placeConditionLast = true) - private def astForSwitchStatement(switchStmt: IASTSwitchStatement): Ast = - val code = s"switch(${nullSafeCode(switchStmt.getControllerExpression)})" - val switchNode = controlStructureNode(switchStmt, ControlStructureTypes.SWITCH, code) - val conditionAst = astForConditionExpression(switchStmt.getControllerExpression) - val stmtAsts = nullSafeAst(switchStmt.getBody) - controlStructureAst(switchNode, Some(conditionAst), stmtAsts) + private def astForSwitchStatement(switchStmt: IASTSwitchStatement): Ast = + val code = s"switch(${nullSafeCode(switchStmt.getControllerExpression)})" + val switchNode = controlStructureNode(switchStmt, ControlStructureTypes.SWITCH, code) + val conditionAst = astForConditionExpression(switchStmt.getControllerExpression) + val stmtAsts = nullSafeAst(switchStmt.getBody) + controlStructureAst(switchNode, Some(conditionAst), stmtAsts) - private def astsForCaseStatement(caseStmt: IASTCaseStatement): Seq[Ast] = - val labelNode = newJumpTargetNode(caseStmt) - val stmt = astForConditionExpression(caseStmt.getExpression) - Seq(Ast(labelNode), stmt) + private def astsForCaseStatement(caseStmt: IASTCaseStatement): Seq[Ast] = + val labelNode = newJumpTargetNode(caseStmt) + val stmt = astForConditionExpression(caseStmt.getExpression) + Seq(Ast(labelNode), stmt) - private def astForDefaultStatement(caseStmt: IASTDefaultStatement): Ast = - Ast(newJumpTargetNode(caseStmt)) + private def astForDefaultStatement(caseStmt: IASTDefaultStatement): Ast = + Ast(newJumpTargetNode(caseStmt)) - private def astForTryStatement(tryStmt: ICPPASTTryBlockStatement): Ast = - val cpgTry = controlStructureNode(tryStmt, ControlStructureTypes.TRY, "try") - val body = nullSafeAst(tryStmt.getTryBody) - // All catches must have order 2 for correct control flow generation. - // TODO fix this. Multiple siblings with the same order are invalid - val catches = tryStmt.getCatchHandlers.flatMap { stmt => - astsForStatement(stmt.getCatchBody, 2) - }.toIndexedSeq - Ast(cpgTry).withChildren(body).withChildren(catches) + private def astForTryStatement(tryStmt: ICPPASTTryBlockStatement): Ast = + val cpgTry = controlStructureNode(tryStmt, ControlStructureTypes.TRY, "try") + val body = nullSafeAst(tryStmt.getTryBody) + // All catches must have order 2 for correct control flow generation. + // TODO fix this. Multiple siblings with the same order are invalid + val catches = tryStmt.getCatchHandlers.flatMap { stmt => + astsForStatement(stmt.getCatchBody, 2) + }.toIndexedSeq + Ast(cpgTry).withChildren(body).withChildren(catches) - protected def astsForStatement(statement: IASTStatement, argIndex: Int = -1): Seq[Ast] = - val r = statement match - case expr: IASTExpressionStatement => Seq(astForExpression(expr.getExpression)) - case block: IASTCompoundStatement => Seq(astForBlockStatement(block, argIndex)) - case ifStmt: IASTIfStatement => Seq(astForIf(ifStmt)) - case whileStmt: IASTWhileStatement => Seq(astForWhile(whileStmt)) - case forStmt: IASTForStatement => Seq(astForFor(forStmt)) - case forStmt: ICPPASTRangeBasedForStatement => Seq(astForRangedFor(forStmt)) - case doStmt: IASTDoStatement => Seq(astForDoStatement(doStmt)) - case switchStmt: IASTSwitchStatement => Seq(astForSwitchStatement(switchStmt)) - case ret: IASTReturnStatement => Seq(astForReturnStatement(ret)) - case br: IASTBreakStatement => Seq(astForBreakStatement(br)) - case cont: IASTContinueStatement => Seq(astForContinueStatement(cont)) - case goto: IASTGotoStatement => Seq(astForGotoStatement(goto)) - case goto: IGNUASTGotoStatement => astsForGnuGotoStatement(goto) - case defStmt: IASTDefaultStatement => Seq(astForDefaultStatement(defStmt)) - case tryStmt: ICPPASTTryBlockStatement => Seq(astForTryStatement(tryStmt)) - case caseStmt: IASTCaseStatement => astsForCaseStatement(caseStmt) - case decl: IASTDeclarationStatement => astsForDeclarationStatement(decl) - case label: IASTLabelStatement => astsForLabelStatement(label) - case _: IASTNullStatement => Seq.empty - case _ => Seq(astForNode(statement)) - r.map(x => asChildOfMacroCall(statement, x)) - end astsForStatement + protected def astsForStatement(statement: IASTStatement, argIndex: Int = -1): Seq[Ast] = + val r = statement match + case expr: IASTExpressionStatement => Seq(astForExpression(expr.getExpression)) + case block: IASTCompoundStatement => Seq(astForBlockStatement(block, argIndex)) + case ifStmt: IASTIfStatement => Seq(astForIf(ifStmt)) + case whileStmt: IASTWhileStatement => Seq(astForWhile(whileStmt)) + case forStmt: IASTForStatement => Seq(astForFor(forStmt)) + case forStmt: ICPPASTRangeBasedForStatement => Seq(astForRangedFor(forStmt)) + case doStmt: IASTDoStatement => Seq(astForDoStatement(doStmt)) + case switchStmt: IASTSwitchStatement => Seq(astForSwitchStatement(switchStmt)) + case ret: IASTReturnStatement => Seq(astForReturnStatement(ret)) + case br: IASTBreakStatement => Seq(astForBreakStatement(br)) + case cont: IASTContinueStatement => Seq(astForContinueStatement(cont)) + case goto: IASTGotoStatement => Seq(astForGotoStatement(goto)) + case goto: IGNUASTGotoStatement => astsForGnuGotoStatement(goto) + case defStmt: IASTDefaultStatement => Seq(astForDefaultStatement(defStmt)) + case tryStmt: ICPPASTTryBlockStatement => Seq(astForTryStatement(tryStmt)) + case caseStmt: IASTCaseStatement => astsForCaseStatement(caseStmt) + case decl: IASTDeclarationStatement => astsForDeclarationStatement(decl) + case label: IASTLabelStatement => astsForLabelStatement(label) + case _: IASTNullStatement => Seq.empty + case _ => Seq(astForNode(statement)) + r.map(x => asChildOfMacroCall(statement, x)) + end astsForStatement - private def astForConditionExpression( - expr: IASTExpression, - explicitArgumentIndex: Option[Int] = None - ): Ast = - val ast = expr match - case exprList: IASTExpressionList => - val compareAstBlock = - blockNode(expr, Defines.empty, registerType(Defines.voidTypeName)) - scope.pushNewScope(compareAstBlock) - val compareBlockAstChildren = exprList.getExpressions.toList.map(nullSafeAst) - setArgumentIndices(compareBlockAstChildren) - val compareBlockAst = blockAst(compareAstBlock, compareBlockAstChildren) - scope.popScope() - compareBlockAst - case other => - nullSafeAst(other) - explicitArgumentIndex.foreach { i => - ast.root.foreach { case expr: ExpressionNew => expr.argumentIndex = i } - } - ast - end astForConditionExpression + private def astForConditionExpression( + expr: IASTExpression, + explicitArgumentIndex: Option[Int] = None + ): Ast = + val ast = expr match + case exprList: IASTExpressionList => + val compareAstBlock = + blockNode(expr, Defines.empty, registerType(Defines.voidTypeName)) + scope.pushNewScope(compareAstBlock) + val compareBlockAstChildren = exprList.getExpressions.toList.map(nullSafeAst) + setArgumentIndices(compareBlockAstChildren) + val compareBlockAst = blockAst(compareAstBlock, compareBlockAstChildren) + scope.popScope() + compareBlockAst + case other => + nullSafeAst(other) + explicitArgumentIndex.foreach { i => + ast.root.foreach { case expr: ExpressionNew => expr.argumentIndex = i } + } + ast + end astForConditionExpression - private def astForFor(forStmt: IASTForStatement): Ast = - val codeInit = nullSafeCode(forStmt.getInitializerStatement) - val codeCond = nullSafeCode(forStmt.getConditionExpression) - val codeIter = nullSafeCode(forStmt.getIterationExpression) + private def astForFor(forStmt: IASTForStatement): Ast = + val codeInit = nullSafeCode(forStmt.getInitializerStatement) + val codeCond = nullSafeCode(forStmt.getConditionExpression) + val codeIter = nullSafeCode(forStmt.getIterationExpression) - val code = s"for ($codeInit$codeCond;$codeIter)" - val forNode = controlStructureNode(forStmt, ControlStructureTypes.FOR, code) + val code = s"for ($codeInit$codeCond;$codeIter)" + val forNode = controlStructureNode(forStmt, ControlStructureTypes.FOR, code) - val initAstBlock = blockNode(forStmt, Defines.empty, registerType(Defines.voidTypeName)) - scope.pushNewScope(initAstBlock) - val initAst = blockAst(initAstBlock, nullSafeAst(forStmt.getInitializerStatement, 1).toList) - scope.popScope() + val initAstBlock = blockNode(forStmt, Defines.empty, registerType(Defines.voidTypeName)) + scope.pushNewScope(initAstBlock) + val initAst = blockAst(initAstBlock, nullSafeAst(forStmt.getInitializerStatement, 1).toList) + scope.popScope() - val compareAst = astForConditionExpression(forStmt.getConditionExpression, Some(2)) - val updateAst = nullSafeAst(forStmt.getIterationExpression, 3) - val bodyAsts = nullSafeAst(forStmt.getBody, 4) - forAst(forNode, Seq(), Seq(initAst), Seq(compareAst), Seq(updateAst), bodyAsts) + val compareAst = astForConditionExpression(forStmt.getConditionExpression, Some(2)) + val updateAst = nullSafeAst(forStmt.getIterationExpression, 3) + val bodyAsts = nullSafeAst(forStmt.getBody, 4) + forAst(forNode, Seq(), Seq(initAst), Seq(compareAst), Seq(updateAst), bodyAsts) - private def astForRangedFor(forStmt: ICPPASTRangeBasedForStatement): Ast = - val codeDecl = nullSafeCode(forStmt.getDeclaration) - val codeInit = nullSafeCode(forStmt.getInitializerClause) + private def astForRangedFor(forStmt: ICPPASTRangeBasedForStatement): Ast = + val codeDecl = nullSafeCode(forStmt.getDeclaration) + val codeInit = nullSafeCode(forStmt.getInitializerClause) - val code = s"for ($codeDecl:$codeInit)" - val forNode = controlStructureNode(forStmt, ControlStructureTypes.FOR, code) + val code = s"for ($codeDecl:$codeInit)" + val forNode = controlStructureNode(forStmt, ControlStructureTypes.FOR, code) - val initAst = astForNode(forStmt.getInitializerClause) - val declAst = astsForDeclaration(forStmt.getDeclaration) - val stmtAst = nullSafeAst(forStmt.getBody) - controlStructureAst(forNode, None, Seq(initAst) ++ declAst ++ stmtAst) + val initAst = astForNode(forStmt.getInitializerClause) + val declAst = astsForDeclaration(forStmt.getDeclaration) + val stmtAst = nullSafeAst(forStmt.getBody) + controlStructureAst(forNode, None, Seq(initAst) ++ declAst ++ stmtAst) - private def astForWhile(whileStmt: IASTWhileStatement): Ast = - val code = s"while (${nullSafeCode(whileStmt.getCondition)})" - val compareAst = astForConditionExpression(whileStmt.getCondition) - val bodyAst = nullSafeAst(whileStmt.getBody) - whileAst( - Some(compareAst), - bodyAst, - Some(code), - lineNumber = line(whileStmt), - columnNumber = column(whileStmt) - ) + private def astForWhile(whileStmt: IASTWhileStatement): Ast = + val code = s"while (${nullSafeCode(whileStmt.getCondition)})" + val compareAst = astForConditionExpression(whileStmt.getCondition) + val bodyAst = nullSafeAst(whileStmt.getBody) + whileAst( + Some(compareAst), + bodyAst, + Some(code), + lineNumber = line(whileStmt), + columnNumber = column(whileStmt) + ) - private def astForIf(ifStmt: IASTIfStatement): Ast = - val (code, conditionAst) = ifStmt match - case s @ (_: CASTIfStatement | _: CPPASTIfStatement) - if s.getConditionExpression != null => - val c = s"if (${nullSafeCode(s.getConditionExpression)})" - val compareAst = astForConditionExpression(s.getConditionExpression) - (c, compareAst) - case s: CPPASTIfStatement if s.getConditionExpression == null => - val c = s"if (${nullSafeCode(s.getConditionDeclaration)})" - val exprBlock = - blockNode(s.getConditionDeclaration, Defines.empty, Defines.voidTypeName) - scope.pushNewScope(exprBlock) - val a = astsForDeclaration(s.getConditionDeclaration) - setArgumentIndices(a) - scope.popScope() - (c, blockAst(exprBlock, a.toList)) + private def astForIf(ifStmt: IASTIfStatement): Ast = + val (code, conditionAst) = ifStmt match + case s @ (_: CASTIfStatement | _: CPPASTIfStatement) + if s.getConditionExpression != null => + val c = s"if (${nullSafeCode(s.getConditionExpression)})" + val compareAst = astForConditionExpression(s.getConditionExpression) + (c, compareAst) + case s: CPPASTIfStatement if s.getConditionExpression == null => + val c = s"if (${nullSafeCode(s.getConditionDeclaration)})" + val exprBlock = + blockNode(s.getConditionDeclaration, Defines.empty, Defines.voidTypeName) + scope.pushNewScope(exprBlock) + val a = astsForDeclaration(s.getConditionDeclaration) + setArgumentIndices(a) + scope.popScope() + (c, blockAst(exprBlock, a.toList)) - val ifNode = controlStructureNode(ifStmt, ControlStructureTypes.IF, code) + val ifNode = controlStructureNode(ifStmt, ControlStructureTypes.IF, code) - val thenAst = ifStmt.getThenClause match - case block: IASTCompoundStatement => astForBlockStatement(block) - case other if other != null => - val thenBlock = blockNode(other, Defines.empty, Defines.voidTypeName) - scope.pushNewScope(thenBlock) - val a = astsForStatement(other) - setArgumentIndices(a) - scope.popScope() - blockAst(thenBlock, a.toList) - case _ => Ast() + val thenAst = ifStmt.getThenClause match + case block: IASTCompoundStatement => astForBlockStatement(block) + case other if other != null => + val thenBlock = blockNode(other, Defines.empty, Defines.voidTypeName) + scope.pushNewScope(thenBlock) + val a = astsForStatement(other) + setArgumentIndices(a) + scope.popScope() + blockAst(thenBlock, a.toList) + case _ => Ast() - val elseAst = ifStmt.getElseClause match - case block: IASTCompoundStatement => - val elseNode = - controlStructureNode(ifStmt.getElseClause, ControlStructureTypes.ELSE, "else") - val elseAst = astForBlockStatement(block) - Ast(elseNode).withChild(elseAst) - case other if other != null => - val elseNode = - controlStructureNode(ifStmt.getElseClause, ControlStructureTypes.ELSE, "else") - val elseBlock = blockNode(other, Defines.empty, Defines.voidTypeName) - scope.pushNewScope(elseBlock) - val a = astsForStatement(other) - setArgumentIndices(a) - scope.popScope() - Ast(elseNode).withChild(blockAst(elseBlock, a.toList)) - case _ => Ast() - controlStructureAst(ifNode, Some(conditionAst), Seq(thenAst, elseAst)) - end astForIf + val elseAst = ifStmt.getElseClause match + case block: IASTCompoundStatement => + val elseNode = + controlStructureNode(ifStmt.getElseClause, ControlStructureTypes.ELSE, "else") + val elseAst = astForBlockStatement(block) + Ast(elseNode).withChild(elseAst) + case other if other != null => + val elseNode = + controlStructureNode(ifStmt.getElseClause, ControlStructureTypes.ELSE, "else") + val elseBlock = blockNode(other, Defines.empty, Defines.voidTypeName) + scope.pushNewScope(elseBlock) + val a = astsForStatement(other) + setArgumentIndices(a) + scope.popScope() + Ast(elseNode).withChild(blockAst(elseBlock, a.toList)) + case _ => Ast() + controlStructureAst(ifNode, Some(conditionAst), Seq(thenAst, elseAst)) + end astForIf end AstForStatementsCreator diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForTypesCreator.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForTypesCreator.scala index f53b24f4..b3d3181e 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForTypesCreator.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstForTypesCreator.scala @@ -10,468 +10,468 @@ import org.eclipse.cdt.internal.core.model.ASTStringUtil import io.appthreat.x2cpg.datastructures.Stack.* trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode): - this: AstCreator => - - private def parentIsClassDef(node: IASTNode): Boolean = Option(node.getParent) match - case Some(_: IASTCompositeTypeSpecifier) => true - case _ => false - - private def isTypeDef(decl: IASTSimpleDeclaration): Boolean = - nodeSignature(decl).startsWith("typedef") - - protected def templateParameters(e: IASTNode): Option[String] = - val templateDeclaration = e match - case _: IASTElaboratedTypeSpecifier | _: IASTFunctionDeclarator | _: IASTCompositeTypeSpecifier - if e.getParent != null => - Option(e.getParent.getParent) - case _: IASTFunctionDefinition if e.getParent != null => Option(e.getParent) - case _ => None - - val decl = templateDeclaration.collect { case t: ICPPASTTemplateDeclaration => t } - val templateParams = - decl.map(d => ASTStringUtil.getTemplateParameterArray(d.getTemplateParameters)) - templateParams.map(_.mkString("<", ",", ">")) - - private def astForNamespaceDefinition(namespaceDefinition: ICPPASTNamespaceDefinition): Ast = - val (name, fullname) = - uniqueName( - "namespace", - namespaceDefinition.getName.getLastName.toString, - fullName(namespaceDefinition) - ) - val code = nodeSignature(namespaceDefinition) - val cpgNamespace = newNamespaceBlockNode( - namespaceDefinition, - name, - fullname, - code, - fileName(namespaceDefinition) + this: AstCreator => + + private def parentIsClassDef(node: IASTNode): Boolean = Option(node.getParent) match + case Some(_: IASTCompositeTypeSpecifier) => true + case _ => false + + private def isTypeDef(decl: IASTSimpleDeclaration): Boolean = + nodeSignature(decl).startsWith("typedef") + + protected def templateParameters(e: IASTNode): Option[String] = + val templateDeclaration = e match + case _: IASTElaboratedTypeSpecifier | _: IASTFunctionDeclarator | _: IASTCompositeTypeSpecifier + if e.getParent != null => + Option(e.getParent.getParent) + case _: IASTFunctionDefinition if e.getParent != null => Option(e.getParent) + case _ => None + + val decl = templateDeclaration.collect { case t: ICPPASTTemplateDeclaration => t } + val templateParams = + decl.map(d => ASTStringUtil.getTemplateParameterArray(d.getTemplateParameters)) + templateParams.map(_.mkString("<", ",", ">")) + + private def astForNamespaceDefinition(namespaceDefinition: ICPPASTNamespaceDefinition): Ast = + val (name, fullname) = + uniqueName( + "namespace", + namespaceDefinition.getName.getLastName.toString, + fullName(namespaceDefinition) ) - scope.pushNewScope(cpgNamespace) - - val childrenAsts = namespaceDefinition.getDeclarations.flatMap { decl => - val declAsts = astsForDeclaration(decl) - declAsts - }.toIndexedSeq - - val namespaceAst = Ast(cpgNamespace).withChildren(childrenAsts) - scope.popScope() - namespaceAst - end astForNamespaceDefinition - - protected def astForNamespaceAlias(namespaceAlias: ICPPASTNamespaceAlias): Ast = - val name = ASTStringUtil.getSimpleName(namespaceAlias.getAlias) - val fullname = fullName(namespaceAlias) - - if !isQualifiedName(name) then - usingDeclarationMappings.put(name, fullname) - - val code = nodeSignature(namespaceAlias) - val cpgNamespace = - newNamespaceBlockNode(namespaceAlias, name, fullname, code, fileName(namespaceAlias)) - Ast(cpgNamespace) - - protected def astForDeclarator( - declaration: IASTSimpleDeclaration, - declarator: IASTDeclarator, - index: Int - ): Ast = - val name = ASTStringUtil.getSimpleName(declarator.getName) - declaration match - case d if isTypeDef(d) && shortName(d.getDeclSpecifier).nonEmpty => - val filename = fileName(declaration) - val tpe = registerType(typeFor(declarator)) - Ast(typeDeclNode( - declarator, - name, - registerType(name), - filename, - nodeSignature(d), - alias = Option(tpe) - )) - case d if parentIsClassDef(d) => - val tpe = declarator match - case _: IASTArrayDeclarator => registerType(typeFor(declarator)) - case _ => registerType(typeForDeclSpecifier(declaration.getDeclSpecifier)) - Ast(memberNode(declarator, name, nodeSignature(declarator), tpe)) - case _ if declarator.isInstanceOf[IASTArrayDeclarator] => - val tpe = registerType(typeFor(declarator)) - val codeTpe = typeFor(declarator, stripKeywords = false) - val node = localNode(declarator, name, s"$codeTpe $name", tpe) - scope.addToScope(name, (node, tpe)) - Ast(node) - case _ => - val tpe = registerType( - cleanType(typeForDeclSpecifier( - declaration.getDeclSpecifier, - stripKeywords = true, - index - )) - ) - val codeTpe = - typeForDeclSpecifier(declaration.getDeclSpecifier, stripKeywords = false, index) - val node = localNode(declarator, name, s"$codeTpe $name", tpe) - scope.addToScope(name, (node, tpe)) - Ast(node) - end match - end astForDeclarator - - protected def astForInitializer(declarator: IASTDeclarator, init: IASTInitializer): Ast = - init match - case i: IASTEqualsInitializer => - val operatorName = Operators.assignment - val left = astForNode(declarator.getName) - val right = astForNode(i.getInitializerClause) - val code = i.getInitializerClause.getRawSignature; - val dispatchType = - if code.nonEmpty && (code.startsWith("&") || code.contains("->")) then - DispatchTypes.DYNAMIC_DISPATCH - else DispatchTypes.STATIC_DISPATCH - val callNode_ = - callNode( - declarator, - nodeSignature(declarator), - operatorName, - operatorName, - dispatchType - ) - callAst(callNode_, List(left, right)) - case i: ICPPASTConstructorInitializer => - val name = ASTStringUtil.getSimpleName(declarator.getName) - val callNode_ = callNode( + val code = nodeSignature(namespaceDefinition) + val cpgNamespace = newNamespaceBlockNode( + namespaceDefinition, + name, + fullname, + code, + fileName(namespaceDefinition) + ) + scope.pushNewScope(cpgNamespace) + + val childrenAsts = namespaceDefinition.getDeclarations.flatMap { decl => + val declAsts = astsForDeclaration(decl) + declAsts + }.toIndexedSeq + + val namespaceAst = Ast(cpgNamespace).withChildren(childrenAsts) + scope.popScope() + namespaceAst + end astForNamespaceDefinition + + protected def astForNamespaceAlias(namespaceAlias: ICPPASTNamespaceAlias): Ast = + val name = ASTStringUtil.getSimpleName(namespaceAlias.getAlias) + val fullname = fullName(namespaceAlias) + + if !isQualifiedName(name) then + usingDeclarationMappings.put(name, fullname) + + val code = nodeSignature(namespaceAlias) + val cpgNamespace = + newNamespaceBlockNode(namespaceAlias, name, fullname, code, fileName(namespaceAlias)) + Ast(cpgNamespace) + + protected def astForDeclarator( + declaration: IASTSimpleDeclaration, + declarator: IASTDeclarator, + index: Int + ): Ast = + val name = ASTStringUtil.getSimpleName(declarator.getName) + declaration match + case d if isTypeDef(d) && shortName(d.getDeclSpecifier).nonEmpty => + val filename = fileName(declaration) + val tpe = registerType(typeFor(declarator)) + Ast(typeDeclNode( + declarator, + name, + registerType(name), + filename, + nodeSignature(d), + alias = Option(tpe) + )) + case d if parentIsClassDef(d) => + val tpe = declarator match + case _: IASTArrayDeclarator => registerType(typeFor(declarator)) + case _ => registerType(typeForDeclSpecifier(declaration.getDeclSpecifier)) + Ast(memberNode(declarator, name, nodeSignature(declarator), tpe)) + case _ if declarator.isInstanceOf[IASTArrayDeclarator] => + val tpe = registerType(typeFor(declarator)) + val codeTpe = typeFor(declarator, stripKeywords = false) + val node = localNode(declarator, name, s"$codeTpe $name", tpe) + scope.addToScope(name, (node, tpe)) + Ast(node) + case _ => + val tpe = registerType( + cleanType(typeForDeclSpecifier( + declaration.getDeclSpecifier, + stripKeywords = true, + index + )) + ) + val codeTpe = + typeForDeclSpecifier(declaration.getDeclSpecifier, stripKeywords = false, index) + val node = localNode(declarator, name, s"$codeTpe $name", tpe) + scope.addToScope(name, (node, tpe)) + Ast(node) + end match + end astForDeclarator + + protected def astForInitializer(declarator: IASTDeclarator, init: IASTInitializer): Ast = + init match + case i: IASTEqualsInitializer => + val operatorName = Operators.assignment + val left = astForNode(declarator.getName) + val right = astForNode(i.getInitializerClause) + val code = i.getInitializerClause.getRawSignature; + val dispatchType = + if code.nonEmpty && (code.startsWith("&") || code.contains("->")) then + DispatchTypes.DYNAMIC_DISPATCH + else DispatchTypes.STATIC_DISPATCH + val callNode_ = + callNode( declarator, nodeSignature(declarator), - name, - name, - DispatchTypes.STATIC_DISPATCH + operatorName, + operatorName, + dispatchType ) - val args = i.getArguments.toList.map(x => astForNode(x)) - callAst(callNode_, args) - case i: IASTInitializerList => - val operatorName = Operators.assignment - val callNode_ = - callNode( - declarator, - nodeSignature(declarator), - operatorName, - operatorName, - DispatchTypes.STATIC_DISPATCH - ) - val left = astForNode(declarator.getName) - val right = astForNode(i) - callAst(callNode_, List(left, right)) - case _ => astForNode(init) - - protected def handleUsingDeclaration(usingDecl: ICPPASTUsingDeclaration): Seq[Ast] = - val simpleName = ASTStringUtil.getSimpleName(usingDecl.getName) - val mappedName = lastNameOfQualifiedName(simpleName) - // we only do the mapping if the declaration is not global because this is already handled by the parser itself - if !isQualifiedName(simpleName) then - usingDecl.getParent match - case ns: ICPPASTNamespaceDefinition => - usingDeclarationMappings.put( - s"${fullName(ns)}.$mappedName", - fixQualifiedName(simpleName) - ) - case _ => - usingDeclarationMappings.put(mappedName, fixQualifiedName(simpleName)) - Seq.empty - - protected def astForAliasDeclaration(aliasDeclaration: ICPPASTAliasDeclaration): Ast = - val name = aliasDeclaration.getAlias.toString - val mappedName = registerType(typeFor(aliasDeclaration.getMappingTypeId)) - val typeDeclNode_ = - typeDeclNode( - aliasDeclaration, + callAst(callNode_, List(left, right)) + case i: ICPPASTConstructorInitializer => + val name = ASTStringUtil.getSimpleName(declarator.getName) + val callNode_ = callNode( + declarator, + nodeSignature(declarator), name, - registerType(name), - fileName(aliasDeclaration), - nodeSignature(aliasDeclaration), - alias = Option(mappedName) - ) - Ast(typeDeclNode_) - - protected def astForASMDeclaration(asm: IASTASMDeclaration): Ast = - Ast(unknownNode(asm, nodeSignature(asm))) - - private def astForStructuredBindingDeclaration(decl: ICPPASTStructuredBindingDeclaration): Ast = - val node = blockNode(decl, Defines.empty, Defines.voidTypeName) - scope.pushNewScope(node) - val childAsts = decl.getNames.toList.map { name => - astForNode(name) - } - scope.popScope() - setArgumentIndices(childAsts) - blockAst(node, childAsts) - - protected def astsForDeclaration(decl: IASTDeclaration): Seq[Ast] = - val declAsts = decl match - case sb: ICPPASTStructuredBindingDeclaration => - Seq(astForStructuredBindingDeclaration(sb)) - case declaration: IASTSimpleDeclaration => - declaration.getDeclSpecifier match - case spec: IASTCompositeTypeSpecifier => - astsForCompositeType(spec, declaration.getDeclarators.toList) - case spec: IASTEnumerationSpecifier => - astsForEnum(spec, declaration.getDeclarators.toList) - case spec: IASTElaboratedTypeSpecifier => - astsForElaboratedType(spec, declaration.getDeclarators.toList) - case spec: IASTNamedTypeSpecifier if declaration.getDeclarators.isEmpty => - val filename = fileName(spec) - val name = ASTStringUtil.getSimpleName(spec.getName) - Seq(Ast(typeDeclNode( - spec, - name, - registerType(name), - filename, - nodeSignature(spec), - alias = Option(name) - ))) - case _ if declaration.getDeclarators.nonEmpty => - declaration.getDeclarators.toIndexedSeq.zipWithIndex.map { - case (d: IASTFunctionDeclarator, _) => - astForFunctionDeclarator(d) - case (d: IASTSimpleDeclaration, _) if d.getInitializer != null => - Ast() // we do the AST for this down below with initAsts - case (d, i) => - astForDeclarator(declaration, d, i) - } - case _ if nodeSignature(declaration) == ";" => - Seq.empty // dangling decls from unresolved macros; we ignore them - case _ - if declaration.getDeclarators.isEmpty && declaration.getParent.isInstanceOf[ - IASTTranslationUnit - ] => - Seq.empty // dangling decls from unresolved macros; we ignore them - case _ if declaration.getDeclarators.isEmpty => Seq(astForNode(declaration)) - case alias: CPPASTAliasDeclaration => Seq(astForAliasDeclaration(alias)) - case functDef: IASTFunctionDefinition => Seq(astForFunctionDefinition(functDef)) - case namespaceAlias: ICPPASTNamespaceAlias => Seq(astForNamespaceAlias(namespaceAlias)) - case namespaceDefinition: ICPPASTNamespaceDefinition => - Seq(astForNamespaceDefinition(namespaceDefinition)) - case a: ICPPASTStaticAssertDeclaration => Seq(astForStaticAssert(a)) - case asm: IASTASMDeclaration => Seq(astForASMDeclaration(asm)) - case t: ICPPASTTemplateDeclaration => astsForDeclaration(t.getDeclaration) - case l: ICPPASTLinkageSpecification => astsForLinkageSpecification(l) - case u: ICPPASTUsingDeclaration => handleUsingDeclaration(u) - case _: ICPPASTVisibilityLabel => Seq.empty - case _: ICPPASTUsingDirective => Seq.empty - case _: ICPPASTExplicitTemplateInstantiation => Seq.empty - case _ => Seq(astForNode(decl)) - - val initAsts = decl match - case declaration: IASTSimpleDeclaration if declaration.getDeclarators.nonEmpty => - declaration.getDeclarators.toList.map { - case d: IASTDeclarator if d.getInitializer != null => - astForInitializer(d, d.getInitializer) - case arrayDecl: IASTArrayDeclarator => - val op = Operators.arrayInitializer - val initCallNode = callNode( - arrayDecl, - nodeSignature(arrayDecl), - op, - op, - DispatchTypes.STATIC_DISPATCH - ) - val initArgs = - arrayDecl.getArrayModifiers.toList.filter(m => - m.getConstantExpression != null - ).map(astForNode) - callAst(initCallNode, initArgs) - case _ => Ast() - } - case _ => Nil - declAsts ++ initAsts - end astsForDeclaration - - private def astsForLinkageSpecification(l: ICPPASTLinkageSpecification): Seq[Ast] = - l.getDeclarations.toList.flatMap { d => - astsForDeclaration(d) - } - - private def astsForCompositeType( - typeSpecifier: IASTCompositeTypeSpecifier, - decls: List[IASTDeclarator] - ): Seq[Ast] = - val filename = fileName(typeSpecifier) - val declAsts = decls.zipWithIndex.map { case (d, i) => - astForDeclarator(typeSpecifier.getParent.asInstanceOf[IASTSimpleDeclaration], d, i) - } - - val lineNumber = line(typeSpecifier) - val columnNumber = column(typeSpecifier) - val fullname = registerType(cleanType(fullName(typeSpecifier))) - val name = ASTStringUtil.getSimpleName(typeSpecifier.getName) match - case n if n.isEmpty => lastNameOfQualifiedName(fullname) - case other => other - val code = nodeSignature(typeSpecifier) - val nameAlias = decls.headOption.map(d => registerType(shortName(d))).filter(_.nonEmpty) - val nameWithTemplateParams = - templateParameters(typeSpecifier).map(t => registerType(s"$fullname$t")) - val alias = (nameAlias.toList ++ nameWithTemplateParams.toList).headOption - - val typeDecl = typeSpecifier match - case cppClass: ICPPASTCompositeTypeSpecifier => - val baseClassList = - cppClass.getBaseSpecifiers.toSeq.map(s => - registerType(s.getNameSpecifier.toString) - ) - typeDeclNode( - typeSpecifier, - name, - fullname, - filename, - code, - inherits = baseClassList, - alias = alias - ) - case _ => - typeDeclNode(typeSpecifier, name, fullname, filename, code, alias = alias) - - methodAstParentStack.push(typeDecl) - scope.pushNewScope(typeDecl) - - val memberAsts = typeSpecifier.getDeclarations(true).toList.flatMap(astsForDeclaration) - - methodAstParentStack.pop() - scope.popScope() - - val (calls, member) = - memberAsts.partition(_.nodes.headOption.exists(_.isInstanceOf[NewCall])) - if calls.isEmpty then - Ast(typeDecl).withChildren(member) +: declAsts - else - val init = staticInitMethodAst( - calls, - s"$fullname:${io.appthreat.x2cpg.Defines.StaticInitMethodName}", - None, - Defines.anyTypeName, - Some(filename), - lineNumber, - columnNumber - ) - Ast(typeDecl).withChildren(member).withChild(init) +: declAsts - end astsForCompositeType - - private def astsForElaboratedType( - typeSpecifier: IASTElaboratedTypeSpecifier, - decls: List[IASTDeclarator] - ): Seq[Ast] = - val filename = fileName(typeSpecifier) - val declAsts = decls.zipWithIndex.map { case (d, i) => - astForDeclarator(typeSpecifier.getParent.asInstanceOf[IASTSimpleDeclaration], d, i) - } - - val name = ASTStringUtil.getSimpleName(typeSpecifier.getName) - val fullname = registerType(cleanType(fullName(typeSpecifier))) - val nameAlias = decls.headOption.map(d => registerType(shortName(d))).filter(_.nonEmpty) - val nameWithTemplateParams = - templateParameters(typeSpecifier).map(t => registerType(s"$fullname$t")) - val alias = (nameAlias.toList ++ nameWithTemplateParams.toList).headOption - - val typeDecl = - typeDeclNode( - typeSpecifier, name, - fullname, - filename, - nodeSignature(typeSpecifier), - alias = alias + DispatchTypes.STATIC_DISPATCH ) - - Ast(typeDecl) +: declAsts - end astsForElaboratedType - - private def astsForEnumerator(enumerator: IASTEnumerationSpecifier.IASTEnumerator): Seq[Ast] = - val tpe = enumerator.getParent match - case enumeration: ICPPASTEnumerationSpecifier if enumeration.getBaseType != null => - enumeration.getBaseType.toString - case _ => typeFor(enumerator) - val cpgMember = memberNode( - enumerator, - ASTStringUtil.getSimpleName(enumerator.getName), - nodeSignature(enumerator), - registerType(cleanType(tpe)) - ) - - if enumerator.getValue != null then + val args = i.getArguments.toList.map(x => astForNode(x)) + callAst(callNode_, args) + case i: IASTInitializerList => val operatorName = Operators.assignment val callNode_ = callNode( - enumerator, - nodeSignature(enumerator), + declarator, + nodeSignature(declarator), operatorName, operatorName, DispatchTypes.STATIC_DISPATCH ) - val left = astForNode(enumerator.getName) - val right = astForNode(enumerator.getValue) - val ast = callAst(callNode_, List(left, right)) - Seq(Ast(cpgMember), ast) - else - Seq(Ast(cpgMember)) - end astsForEnumerator - - private def astsForEnum( - typeSpecifier: IASTEnumerationSpecifier, - decls: List[IASTDeclarator] - ): Seq[Ast] = - val filename = fileName(typeSpecifier) - val declAsts = decls.zipWithIndex.map { case (d, i) => - astForDeclarator(typeSpecifier.getParent.asInstanceOf[IASTSimpleDeclaration], d, i) - } - - val lineNumber = line(typeSpecifier) - val columnNumber = column(typeSpecifier) - val (name, fullname) = - uniqueName( - "enum", - ASTStringUtil.getSimpleName(typeSpecifier.getName), - fullName(typeSpecifier) + val left = astForNode(declarator.getName) + val right = astForNode(i) + callAst(callNode_, List(left, right)) + case _ => astForNode(init) + + protected def handleUsingDeclaration(usingDecl: ICPPASTUsingDeclaration): Seq[Ast] = + val simpleName = ASTStringUtil.getSimpleName(usingDecl.getName) + val mappedName = lastNameOfQualifiedName(simpleName) + // we only do the mapping if the declaration is not global because this is already handled by the parser itself + if !isQualifiedName(simpleName) then + usingDecl.getParent match + case ns: ICPPASTNamespaceDefinition => + usingDeclarationMappings.put( + s"${fullName(ns)}.$mappedName", + fixQualifiedName(simpleName) ) - val alias = decls.headOption.map(d => registerType(shortName(d))).filter(_.nonEmpty) - - val (deAliasedName, deAliasedFullName, newAlias) = - if name.contains("anonymous_enum") && alias.isDefined then - ( - alias.get, - fullname.substring(0, fullname.indexOf("anonymous_enum")) + alias.get, - None - ) - else (name, fullname, alias) - - val typeDecl = - typeDeclNode( - typeSpecifier, - deAliasedName, - registerType(deAliasedFullName), - filename, - nodeSignature(typeSpecifier), - alias = newAlias - ) - methodAstParentStack.push(typeDecl) - scope.pushNewScope(typeDecl) - - val memberAsts = typeSpecifier.getEnumerators.toList.flatMap { e => - astsForEnumerator(e) - } - methodAstParentStack.pop() - scope.popScope() - - val (calls, member) = - memberAsts.partition(_.nodes.headOption.exists(_.isInstanceOf[NewCall])) - if calls.isEmpty then - Ast(typeDecl).withChildren(member) +: declAsts - else - val init = staticInitMethodAst( - calls, - s"$deAliasedFullName:${io.appthreat.x2cpg.Defines.StaticInitMethodName}", - None, - Defines.anyTypeName, - Some(filename), - lineNumber, - columnNumber - ) - Ast(typeDecl).withChildren(member).withChild(init) +: declAsts - end astsForEnum + case _ => + usingDeclarationMappings.put(mappedName, fixQualifiedName(simpleName)) + Seq.empty + + protected def astForAliasDeclaration(aliasDeclaration: ICPPASTAliasDeclaration): Ast = + val name = aliasDeclaration.getAlias.toString + val mappedName = registerType(typeFor(aliasDeclaration.getMappingTypeId)) + val typeDeclNode_ = + typeDeclNode( + aliasDeclaration, + name, + registerType(name), + fileName(aliasDeclaration), + nodeSignature(aliasDeclaration), + alias = Option(mappedName) + ) + Ast(typeDeclNode_) + + protected def astForASMDeclaration(asm: IASTASMDeclaration): Ast = + Ast(unknownNode(asm, nodeSignature(asm))) + + private def astForStructuredBindingDeclaration(decl: ICPPASTStructuredBindingDeclaration): Ast = + val node = blockNode(decl, Defines.empty, Defines.voidTypeName) + scope.pushNewScope(node) + val childAsts = decl.getNames.toList.map { name => + astForNode(name) + } + scope.popScope() + setArgumentIndices(childAsts) + blockAst(node, childAsts) + + protected def astsForDeclaration(decl: IASTDeclaration): Seq[Ast] = + val declAsts = decl match + case sb: ICPPASTStructuredBindingDeclaration => + Seq(astForStructuredBindingDeclaration(sb)) + case declaration: IASTSimpleDeclaration => + declaration.getDeclSpecifier match + case spec: IASTCompositeTypeSpecifier => + astsForCompositeType(spec, declaration.getDeclarators.toList) + case spec: IASTEnumerationSpecifier => + astsForEnum(spec, declaration.getDeclarators.toList) + case spec: IASTElaboratedTypeSpecifier => + astsForElaboratedType(spec, declaration.getDeclarators.toList) + case spec: IASTNamedTypeSpecifier if declaration.getDeclarators.isEmpty => + val filename = fileName(spec) + val name = ASTStringUtil.getSimpleName(spec.getName) + Seq(Ast(typeDeclNode( + spec, + name, + registerType(name), + filename, + nodeSignature(spec), + alias = Option(name) + ))) + case _ if declaration.getDeclarators.nonEmpty => + declaration.getDeclarators.toIndexedSeq.zipWithIndex.map { + case (d: IASTFunctionDeclarator, _) => + astForFunctionDeclarator(d) + case (d: IASTSimpleDeclaration, _) if d.getInitializer != null => + Ast() // we do the AST for this down below with initAsts + case (d, i) => + astForDeclarator(declaration, d, i) + } + case _ if nodeSignature(declaration) == ";" => + Seq.empty // dangling decls from unresolved macros; we ignore them + case _ + if declaration.getDeclarators.isEmpty && declaration.getParent.isInstanceOf[ + IASTTranslationUnit + ] => + Seq.empty // dangling decls from unresolved macros; we ignore them + case _ if declaration.getDeclarators.isEmpty => Seq(astForNode(declaration)) + case alias: CPPASTAliasDeclaration => Seq(astForAliasDeclaration(alias)) + case functDef: IASTFunctionDefinition => Seq(astForFunctionDefinition(functDef)) + case namespaceAlias: ICPPASTNamespaceAlias => Seq(astForNamespaceAlias(namespaceAlias)) + case namespaceDefinition: ICPPASTNamespaceDefinition => + Seq(astForNamespaceDefinition(namespaceDefinition)) + case a: ICPPASTStaticAssertDeclaration => Seq(astForStaticAssert(a)) + case asm: IASTASMDeclaration => Seq(astForASMDeclaration(asm)) + case t: ICPPASTTemplateDeclaration => astsForDeclaration(t.getDeclaration) + case l: ICPPASTLinkageSpecification => astsForLinkageSpecification(l) + case u: ICPPASTUsingDeclaration => handleUsingDeclaration(u) + case _: ICPPASTVisibilityLabel => Seq.empty + case _: ICPPASTUsingDirective => Seq.empty + case _: ICPPASTExplicitTemplateInstantiation => Seq.empty + case _ => Seq(astForNode(decl)) + + val initAsts = decl match + case declaration: IASTSimpleDeclaration if declaration.getDeclarators.nonEmpty => + declaration.getDeclarators.toList.map { + case d: IASTDeclarator if d.getInitializer != null => + astForInitializer(d, d.getInitializer) + case arrayDecl: IASTArrayDeclarator => + val op = Operators.arrayInitializer + val initCallNode = callNode( + arrayDecl, + nodeSignature(arrayDecl), + op, + op, + DispatchTypes.STATIC_DISPATCH + ) + val initArgs = + arrayDecl.getArrayModifiers.toList.filter(m => + m.getConstantExpression != null + ).map(astForNode) + callAst(initCallNode, initArgs) + case _ => Ast() + } + case _ => Nil + declAsts ++ initAsts + end astsForDeclaration + + private def astsForLinkageSpecification(l: ICPPASTLinkageSpecification): Seq[Ast] = + l.getDeclarations.toList.flatMap { d => + astsForDeclaration(d) + } + + private def astsForCompositeType( + typeSpecifier: IASTCompositeTypeSpecifier, + decls: List[IASTDeclarator] + ): Seq[Ast] = + val filename = fileName(typeSpecifier) + val declAsts = decls.zipWithIndex.map { case (d, i) => + astForDeclarator(typeSpecifier.getParent.asInstanceOf[IASTSimpleDeclaration], d, i) + } + + val lineNumber = line(typeSpecifier) + val columnNumber = column(typeSpecifier) + val fullname = registerType(cleanType(fullName(typeSpecifier))) + val name = ASTStringUtil.getSimpleName(typeSpecifier.getName) match + case n if n.isEmpty => lastNameOfQualifiedName(fullname) + case other => other + val code = nodeSignature(typeSpecifier) + val nameAlias = decls.headOption.map(d => registerType(shortName(d))).filter(_.nonEmpty) + val nameWithTemplateParams = + templateParameters(typeSpecifier).map(t => registerType(s"$fullname$t")) + val alias = (nameAlias.toList ++ nameWithTemplateParams.toList).headOption + + val typeDecl = typeSpecifier match + case cppClass: ICPPASTCompositeTypeSpecifier => + val baseClassList = + cppClass.getBaseSpecifiers.toSeq.map(s => + registerType(s.getNameSpecifier.toString) + ) + typeDeclNode( + typeSpecifier, + name, + fullname, + filename, + code, + inherits = baseClassList, + alias = alias + ) + case _ => + typeDeclNode(typeSpecifier, name, fullname, filename, code, alias = alias) + + methodAstParentStack.push(typeDecl) + scope.pushNewScope(typeDecl) + + val memberAsts = typeSpecifier.getDeclarations(true).toList.flatMap(astsForDeclaration) + + methodAstParentStack.pop() + scope.popScope() + + val (calls, member) = + memberAsts.partition(_.nodes.headOption.exists(_.isInstanceOf[NewCall])) + if calls.isEmpty then + Ast(typeDecl).withChildren(member) +: declAsts + else + val init = staticInitMethodAst( + calls, + s"$fullname.${io.appthreat.x2cpg.Defines.StaticInitMethodName}", + None, + Defines.anyTypeName, + Some(filename), + lineNumber, + columnNumber + ) + Ast(typeDecl).withChildren(member).withChild(init) +: declAsts + end astsForCompositeType + + private def astsForElaboratedType( + typeSpecifier: IASTElaboratedTypeSpecifier, + decls: List[IASTDeclarator] + ): Seq[Ast] = + val filename = fileName(typeSpecifier) + val declAsts = decls.zipWithIndex.map { case (d, i) => + astForDeclarator(typeSpecifier.getParent.asInstanceOf[IASTSimpleDeclaration], d, i) + } + + val name = ASTStringUtil.getSimpleName(typeSpecifier.getName) + val fullname = registerType(cleanType(fullName(typeSpecifier))) + val nameAlias = decls.headOption.map(d => registerType(shortName(d))).filter(_.nonEmpty) + val nameWithTemplateParams = + templateParameters(typeSpecifier).map(t => registerType(s"$fullname$t")) + val alias = (nameAlias.toList ++ nameWithTemplateParams.toList).headOption + + val typeDecl = + typeDeclNode( + typeSpecifier, + name, + fullname, + filename, + nodeSignature(typeSpecifier), + alias = alias + ) + + Ast(typeDecl) +: declAsts + end astsForElaboratedType + + private def astsForEnumerator(enumerator: IASTEnumerationSpecifier.IASTEnumerator): Seq[Ast] = + val tpe = enumerator.getParent match + case enumeration: ICPPASTEnumerationSpecifier if enumeration.getBaseType != null => + enumeration.getBaseType.toString + case _ => typeFor(enumerator) + val cpgMember = memberNode( + enumerator, + ASTStringUtil.getSimpleName(enumerator.getName), + nodeSignature(enumerator), + registerType(cleanType(tpe)) + ) + + if enumerator.getValue != null then + val operatorName = Operators.assignment + val callNode_ = + callNode( + enumerator, + nodeSignature(enumerator), + operatorName, + operatorName, + DispatchTypes.STATIC_DISPATCH + ) + val left = astForNode(enumerator.getName) + val right = astForNode(enumerator.getValue) + val ast = callAst(callNode_, List(left, right)) + Seq(Ast(cpgMember), ast) + else + Seq(Ast(cpgMember)) + end astsForEnumerator + + private def astsForEnum( + typeSpecifier: IASTEnumerationSpecifier, + decls: List[IASTDeclarator] + ): Seq[Ast] = + val filename = fileName(typeSpecifier) + val declAsts = decls.zipWithIndex.map { case (d, i) => + astForDeclarator(typeSpecifier.getParent.asInstanceOf[IASTSimpleDeclaration], d, i) + } + + val lineNumber = line(typeSpecifier) + val columnNumber = column(typeSpecifier) + val (name, fullname) = + uniqueName( + "enum", + ASTStringUtil.getSimpleName(typeSpecifier.getName), + fullName(typeSpecifier) + ) + val alias = decls.headOption.map(d => registerType(shortName(d))).filter(_.nonEmpty) + + val (deAliasedName, deAliasedFullName, newAlias) = + if name.contains("anonymous_enum") && alias.isDefined then + ( + alias.get, + fullname.substring(0, fullname.indexOf("anonymous_enum")) + alias.get, + None + ) + else (name, fullname, alias) + + val typeDecl = + typeDeclNode( + typeSpecifier, + deAliasedName, + registerType(deAliasedFullName), + filename, + nodeSignature(typeSpecifier), + alias = newAlias + ) + methodAstParentStack.push(typeDecl) + scope.pushNewScope(typeDecl) + + val memberAsts = typeSpecifier.getEnumerators.toList.flatMap { e => + astsForEnumerator(e) + } + methodAstParentStack.pop() + scope.popScope() + + val (calls, member) = + memberAsts.partition(_.nodes.headOption.exists(_.isInstanceOf[NewCall])) + if calls.isEmpty then + Ast(typeDecl).withChildren(member) +: declAsts + else + val init = staticInitMethodAst( + calls, + s"$deAliasedFullName:${io.appthreat.x2cpg.Defines.StaticInitMethodName}", + None, + Defines.anyTypeName, + Some(filename), + lineNumber, + columnNumber + ) + Ast(typeDecl).withChildren(member).withChild(init) +: declAsts + end astsForEnum end AstForTypesCreator diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstNodeBuilder.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstNodeBuilder.scala index c7be9669..f577af60 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstNodeBuilder.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/AstNodeBuilder.scala @@ -7,39 +7,39 @@ import org.eclipse.cdt.core.dom.ast.IASTPreprocessorIncludeStatement import org.eclipse.cdt.internal.core.model.ASTStringUtil trait AstNodeBuilder: - this: AstCreator => - protected def newCommentNode(node: IASTNode, code: String, filename: String): NewComment = - NewComment().code(code).filename(filename).lineNumber(line(node)).columnNumber(column(node)) + this: AstCreator => + protected def newCommentNode(node: IASTNode, code: String, filename: String): NewComment = + NewComment().code(code).filename(filename).lineNumber(line(node)).columnNumber(column(node)) - protected def newNamespaceBlockNode( - node: IASTNode, - name: String, - fullname: String, - code: String, - filename: String - ): NewNamespaceBlock = - NewNamespaceBlock() - .code(code) - .lineNumber(line(node)) - .columnNumber(column(node)) - .filename(filename) - .name(name) - .fullName(fullname) + protected def newNamespaceBlockNode( + node: IASTNode, + name: String, + fullname: String, + code: String, + filename: String + ): NewNamespaceBlock = + NewNamespaceBlock() + .code(code) + .lineNumber(line(node)) + .columnNumber(column(node)) + .filename(filename) + .name(name) + .fullName(fullname) - // TODO: We should get rid of this method as its being used at multiple places and use it from x2cpg/AstNodeBuilder "methodReturnNode" - protected def newMethodReturnNode(node: IASTNode, typeFullName: String): NewMethodReturn = - newMethodReturnNode_(typeFullName, None, line(node), column(node)) + // TODO: We should get rid of this method as its being used at multiple places and use it from x2cpg/AstNodeBuilder "methodReturnNode" + protected def newMethodReturnNode(node: IASTNode, typeFullName: String): NewMethodReturn = + newMethodReturnNode_(typeFullName, None, line(node), column(node)) - protected def newJumpTargetNode(node: IASTNode): NewJumpTarget = - val code = nodeSignature(node) - val name = node match - case label: IASTLabelStatement => ASTStringUtil.getSimpleName(label.getName) - case _ if code.startsWith("case") => "case" - case _ => "default" - NewJumpTarget() - .parserTypeName(node.getClass.getSimpleName) - .name(name) - .code(code) - .lineNumber(line(node)) - .columnNumber(column(node)) + protected def newJumpTargetNode(node: IASTNode): NewJumpTarget = + val code = nodeSignature(node) + val name = node match + case label: IASTLabelStatement => ASTStringUtil.getSimpleName(label.getName) + case _ if code.startsWith("case") => "case" + case _ => "default" + NewJumpTarget() + .parserTypeName(node.getClass.getSimpleName) + .name(name) + .code(code) + .lineNumber(line(node)) + .columnNumber(column(node)) end AstNodeBuilder diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/Defines.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/Defines.scala index 32705c20..2c3c3ac9 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/Defines.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/Defines.scala @@ -1,7 +1,19 @@ package io.appthreat.c2cpg.astcreation object Defines: - val anyTypeName: String = "ANY" - val voidTypeName: String = "void" - val qualifiedNameSeparator: String = "::" - val empty = "" + val anyTypeName: String = "ANY" + val voidTypeName: String = "void" + val qualifiedNameSeparator: String = "::" + val empty = "" + val operatorPointerCall = ".pointerCall" + val operatorConstructorInitializer = ".constructorInitializer" + val operatorTypeOf = ".typeOf" + val operatorMax = ".max" + val operatorMin = ".min" + val operatorEllipses = ".op_ellipses" + val operatorUnknown = ".unknown" + val operatorCall = "()" + val operatorExpressionList = ".expressionList" + val operatorNew = ".new" + val operatorThrow = ".throw" + val operatorBracketedPrimary = ".bracketedPrimary" diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/MacroHandler.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/MacroHandler.scala index 48904848..db0a7d7e 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/MacroHandler.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/astcreation/MacroHandler.scala @@ -24,171 +24,171 @@ import scala.annotation.nowarn import scala.collection.mutable trait MacroHandler(implicit withSchemaValidation: ValidationMode): - this: AstCreator => - - private val nodeOffsetMacroPairs: mutable.Stack[(Int, IASTPreprocessorMacroDefinition)] = - mutable.Stack.from( - cdtAst.getNodeLocations.toList - .collect { case exp: IASTMacroExpansionLocation => - (exp.asFileLocation().getNodeOffset, exp.getExpansion.getMacroDefinition) - } - .sortBy(_._1) - ) - - /** For the given node, determine if it is expanded from a macro, and if so, create a Call node - * to represent the macro invocation and attach `ast` as its child. - */ - def asChildOfMacroCall(node: IASTNode, ast: Ast): Ast = - // If a macro in a header file contained a method definition already seen in some - // source file we skipped that during the previous AST creation and returned an empty AST. - if ast.root.isEmpty && isExpandedFromMacro(node) then return ast - // We do nothing for locals only. - if ast.nodes.size == 1 && ast.root.exists(_.isInstanceOf[NewLocal]) then return ast - // Otherwise, we create the synthetic call AST. - val matchingMacro = extractMatchingMacro(node) - val macroCallAst = matchingMacro.map { case (mac, args) => - createMacroCallAst(ast, node, mac, args) - } - macroCallAst match - case Some(callAst) => - val lostLocals = ast.refEdges.collect { case AstEdge(_, dst: NewLocal) => - Ast(dst) - }.toList - val newAst = ast.subTreeCopy(ast.root.get.asInstanceOf[AstNodeNew], argIndex = 1) - // We need to wrap the copied AST as it may contain CPG nodes not being allowed - // to be connected via AST edges under a CALL. E.g., LOCALs but only if its not already a BLOCK. - val childAst = newAst.root match - case Some(_: NewBlock) => - newAst - case _ => - val b = NewBlock().argumentIndex(1).typeFullName( - registerType(Defines.voidTypeName) - ) - blockAst(b, List(newAst)) - callAst.withChildren(lostLocals).withChild(childAst) - case None => ast - end asChildOfMacroCall - - /** For the given node, determine if it is expanded from a macro, and if so, find the first - * matching (offset, macro) pair in nodeOffsetMacroPairs, removing non-matching elements from - * the start of nodeOffsetMacroPairs. Returns (Some(macroDefinition, arguments)) if a macro - * definition matches and None otherwise. - */ - private def extractMatchingMacro(node: IASTNode) - : Option[(IASTPreprocessorMacroDefinition, List[String])] = - val expansionLocations = - expandedFromMacro(node).filterNot(isExpandedFrom(node.getParent, _)) - val nodeOffset = node.getFileLocation.getNodeOffset - var matchingMacro = Option.empty[(IASTPreprocessorMacroDefinition, List[String])] - - expansionLocations.foreach { macroLocation => - while matchingMacro.isEmpty && nodeOffsetMacroPairs.headOption.exists( - _._1 <= nodeOffset - ) - do - val (_, macroDefinition) = nodeOffsetMacroPairs.pop() - val macroExpansionName = ASTStringUtil.getSimpleName( - macroLocation.getExpansion.getMacroDefinition.getName - ) - val macroDefinitionName = ASTStringUtil.getSimpleName(macroDefinition.getName) - if macroExpansionName == macroDefinitionName then - matchingMacro = Option((macroDefinition, List[String]())) - } - - matchingMacro - end extractMatchingMacro - - /** Determine whether `node` is expanded from the macro expansion at `loc`. - */ - private def isExpandedFrom(node: IASTNode, loc: IASTMacroExpansionLocation): Boolean = - expandedFromMacro(node).map(_.getExpansion.getMacroDefinition).contains( - loc.getExpansion.getMacroDefinition - ) - - private def argumentTrees(arguments: List[String], ast: Ast): List[Option[Ast]] = - arguments.zipWithIndex.map { case (arg, i) => - val rootNode = argForCode(arg, ast) - rootNode.map(x => ast.subTreeCopy(x.asInstanceOf[AstNodeNew], i + 1)) - } - - private def argForCode(code: String, ast: Ast): Option[NewNode] = - val normalizedCode = code.replace(" ", "") - if normalizedCode == "" then - None - else - ast.nodes.collectFirst { - case x: ExpressionNew - if !x.isInstanceOf[NewFieldIdentifier] && x.code == normalizedCode => x - } + this: AstCreator => - /** Create an AST that represents a macro expansion as a call. The AST is rooted in a CALL node - * and contains sub trees for arguments. These are also connected to the AST via ARGUMENT - * edges. We include line number information in the CALL node that is picked up by the - * MethodStubCreator. - */ - private def createMacroCallAst( - ast: Ast, - node: IASTNode, - macroDef: IASTPreprocessorMacroDefinition, - arguments: List[String] - ): Ast = - val name = ASTStringUtil.getSimpleName(macroDef.getName) - val code = node.getRawSignature.stripSuffix(";") - val argAsts = argumentTrees(arguments, ast).map(_.getOrElse(Ast())) - - val callName = StringUtils.normalizeSpace(name) - val callFullName = StringUtils.normalizeSpace(fullName(macroDef, argAsts)) - val callNode = - NewCall() - .name(callName) - .dispatchType(DispatchTypes.INLINED) - .methodFullName(callFullName) - .code(code) - .typeFullName(typeFor(node)) - .lineNumber(line(node)) - .columnNumber(column(node)) - callAst(callNode, argAsts) - end createMacroCallAst - - /** Create a full name field that encodes line information that can be picked up by the - * MethodStubCreator in order to create a METHOD node with the correct location information. - */ - private def fullName(macroDef: IASTPreprocessorMacroDefinition, argAsts: List[Ast]) = - val name = ASTStringUtil.getSimpleName(macroDef.getName) - val filename = fileName(macroDef) - val lineNo: Integer = line(macroDef).getOrElse(-1) - val lineNoEnd: Integer = lineEnd(macroDef).getOrElse(-1) - if name != "NULL" then s"$filename:$lineNo:$lineNoEnd:$name:${argAsts.size}" else name - - /** The CDT utility method is unfortunately in a class that is marked as deprecated, however, - * this is because the CDT team would like to discourage its use but at the same time does not - * plan to remove this code. - */ - @nowarn - def nodeSignature(node: IASTNode): String = - import org.eclipse.cdt.core.dom.ast.ASTSignatureUtil.getNodeSignature - val sig = if isExpandedFromMacro(node) then - val sig = getNodeSignature(node) - if sig.isEmpty then - node.getRawSignature - else - sig - else - node.getRawSignature - shortenCode(sig) - - private def isExpandedFromMacro(node: IASTNode): Boolean = expandedFromMacro(node).nonEmpty - - private def expandedFromMacro(node: IASTNode): Option[IASTMacroExpansionLocation] = - val locations = node.getNodeLocations.toList - val locationsSorted = node match - // For binary expressions the expansion locations may occur in any order. - // We manually sort them here to ignore this. - // TODO: This may also happen with other expressions that allow for multiple sub elements. - case _: IASTBinaryExpression => - locations.sortBy(_.isInstanceOf[IASTMacroExpansionLocation]) - case _ => locations - locationsSorted match - case (head: IASTMacroExpansionLocation) :: _ => Option(head) - case _ => None + private val nodeOffsetMacroPairs: mutable.Stack[(Int, IASTPreprocessorMacroDefinition)] = + mutable.Stack.from( + cdtAst.getNodeLocations.toList + .collect { case exp: IASTMacroExpansionLocation => + (exp.asFileLocation().getNodeOffset, exp.getExpansion.getMacroDefinition) + } + .sortBy(_._1) + ) + + /** For the given node, determine if it is expanded from a macro, and if so, create a Call node to + * represent the macro invocation and attach `ast` as its child. + */ + def asChildOfMacroCall(node: IASTNode, ast: Ast): Ast = + // If a macro in a header file contained a method definition already seen in some + // source file we skipped that during the previous AST creation and returned an empty AST. + if ast.root.isEmpty && isExpandedFromMacro(node) then return ast + // We do nothing for locals only. + if ast.nodes.size == 1 && ast.root.exists(_.isInstanceOf[NewLocal]) then return ast + // Otherwise, we create the synthetic call AST. + val matchingMacro = extractMatchingMacro(node) + val macroCallAst = matchingMacro.map { case (mac, args) => + createMacroCallAst(ast, node, mac, args) + } + macroCallAst match + case Some(callAst) => + val lostLocals = ast.refEdges.collect { case AstEdge(_, dst: NewLocal) => + Ast(dst) + }.toList + val newAst = ast.subTreeCopy(ast.root.get.asInstanceOf[AstNodeNew], argIndex = 1) + // We need to wrap the copied AST as it may contain CPG nodes not being allowed + // to be connected via AST edges under a CALL. E.g., LOCALs but only if its not already a BLOCK. + val childAst = newAst.root match + case Some(_: NewBlock) => + newAst + case _ => + val b = NewBlock().argumentIndex(1).typeFullName( + registerType(Defines.voidTypeName) + ) + blockAst(b, List(newAst)) + callAst.withChildren(lostLocals).withChild(childAst) + case None => ast + end asChildOfMacroCall + + /** For the given node, determine if it is expanded from a macro, and if so, find the first + * matching (offset, macro) pair in nodeOffsetMacroPairs, removing non-matching elements from the + * start of nodeOffsetMacroPairs. Returns (Some(macroDefinition, arguments)) if a macro + * definition matches and None otherwise. + */ + private def extractMatchingMacro(node: IASTNode) + : Option[(IASTPreprocessorMacroDefinition, List[String])] = + val expansionLocations = + expandedFromMacro(node).filterNot(isExpandedFrom(node.getParent, _)) + val nodeOffset = node.getFileLocation.getNodeOffset + var matchingMacro = Option.empty[(IASTPreprocessorMacroDefinition, List[String])] + + expansionLocations.foreach { macroLocation => + while matchingMacro.isEmpty && nodeOffsetMacroPairs.headOption.exists( + _._1 <= nodeOffset + ) + do + val (_, macroDefinition) = nodeOffsetMacroPairs.pop() + val macroExpansionName = ASTStringUtil.getSimpleName( + macroLocation.getExpansion.getMacroDefinition.getName + ) + val macroDefinitionName = ASTStringUtil.getSimpleName(macroDefinition.getName) + if macroExpansionName == macroDefinitionName then + matchingMacro = Option((macroDefinition, List[String]())) + } + + matchingMacro + end extractMatchingMacro + + /** Determine whether `node` is expanded from the macro expansion at `loc`. + */ + private def isExpandedFrom(node: IASTNode, loc: IASTMacroExpansionLocation): Boolean = + expandedFromMacro(node).map(_.getExpansion.getMacroDefinition).contains( + loc.getExpansion.getMacroDefinition + ) + + private def argumentTrees(arguments: List[String], ast: Ast): List[Option[Ast]] = + arguments.zipWithIndex.map { case (arg, i) => + val rootNode = argForCode(arg, ast) + rootNode.map(x => ast.subTreeCopy(x.asInstanceOf[AstNodeNew], i + 1)) + } + + private def argForCode(code: String, ast: Ast): Option[NewNode] = + val normalizedCode = code.replace(" ", "") + if normalizedCode == "" then + None + else + ast.nodes.collectFirst { + case x: ExpressionNew + if !x.isInstanceOf[NewFieldIdentifier] && x.code == normalizedCode => x + } + + /** Create an AST that represents a macro expansion as a call. The AST is rooted in a CALL node + * and contains sub trees for arguments. These are also connected to the AST via ARGUMENT edges. + * We include line number information in the CALL node that is picked up by the + * MethodStubCreator. + */ + private def createMacroCallAst( + ast: Ast, + node: IASTNode, + macroDef: IASTPreprocessorMacroDefinition, + arguments: List[String] + ): Ast = + val name = ASTStringUtil.getSimpleName(macroDef.getName) + val code = node.getRawSignature.stripSuffix(";") + val argAsts = argumentTrees(arguments, ast).map(_.getOrElse(Ast())) + + val callName = StringUtils.normalizeSpace(name) + val callFullName = StringUtils.normalizeSpace(fullName(macroDef, argAsts)) + val callNode = + NewCall() + .name(callName) + .dispatchType(DispatchTypes.INLINED) + .methodFullName(callFullName) + .code(code) + .typeFullName(typeFor(node)) + .lineNumber(line(node)) + .columnNumber(column(node)) + callAst(callNode, argAsts) + end createMacroCallAst + + /** Create a full name field that encodes line information that can be picked up by the + * MethodStubCreator in order to create a METHOD node with the correct location information. + */ + private def fullName(macroDef: IASTPreprocessorMacroDefinition, argAsts: List[Ast]) = + val name = ASTStringUtil.getSimpleName(macroDef.getName) + val filename = fileName(macroDef) + val lineNo: Integer = line(macroDef).getOrElse(-1) + val lineNoEnd: Integer = lineEnd(macroDef).getOrElse(-1) + if name != "NULL" then s"$filename:$lineNo:$lineNoEnd:$name:${argAsts.size}" else name + + /** The CDT utility method is unfortunately in a class that is marked as deprecated, however, this + * is because the CDT team would like to discourage its use but at the same time does not plan to + * remove this code. + */ + @nowarn + def nodeSignature(node: IASTNode): String = + import org.eclipse.cdt.core.dom.ast.ASTSignatureUtil.getNodeSignature + val sig = if isExpandedFromMacro(node) then + val sig = getNodeSignature(node) + if sig.isEmpty then + node.getRawSignature + else + sig + else + node.getRawSignature + shortenCode(sig) + + private def isExpandedFromMacro(node: IASTNode): Boolean = expandedFromMacro(node).nonEmpty + + private def expandedFromMacro(node: IASTNode): Option[IASTMacroExpansionLocation] = + val locations = node.getNodeLocations.toList + val locationsSorted = node match + // For binary expressions the expansion locations may occur in any order. + // We manually sort them here to ignore this. + // TODO: This may also happen with other expressions that allow for multiple sub elements. + case _: IASTBinaryExpression => + locations.sortBy(_.isInstanceOf[IASTMacroExpansionLocation]) + case _ => locations + locationsSorted match + case (head: IASTMacroExpansionLocation) :: _ => Option(head) + case _ => None end MacroHandler diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/datastructures/CGlobal.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/datastructures/CGlobal.scala index b28e4b82..1fd92706 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/datastructures/CGlobal.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/datastructures/CGlobal.scala @@ -7,7 +7,7 @@ import scala.jdk.CollectionConverters.* object CGlobal extends Global: - def typesSeen(): List[String] = - val types = usedTypes.keys().asScala.filterNot(_ == Defines.anyTypeName).toList - usedTypes.clear() - types + def typesSeen(): List[String] = + val types = usedTypes.keys().asScala.filterNot(_ == Defines.anyTypeName).toList + usedTypes.clear() + types diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/CdtParser.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/CdtParser.scala index 1d5b5d0d..ab5415f0 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/CdtParser.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/CdtParser.scala @@ -19,139 +19,139 @@ import scala.jdk.CollectionConverters.* object CdtParser: - private val logger = LoggerFactory.getLogger(classOf[CdtParser]) + private val logger = LoggerFactory.getLogger(classOf[CdtParser]) - private case class ParseResult( - translationUnit: Option[IASTTranslationUnit], - preprocessorErrorCount: Int = 0, - problems: Int = 0, - failure: Option[Throwable] = None - ) + private case class ParseResult( + translationUnit: Option[IASTTranslationUnit], + preprocessorErrorCount: Int = 0, + problems: Int = 0, + failure: Option[Throwable] = None + ) - def readFileAsFileContent(path: Path): FileContent = - val lines = IOUtils.readLinesInFile(path).mkString("\n").toArray - FileContent.create(path.toString, true, lines) + def readFileAsFileContent(path: Path): FileContent = + val lines = IOUtils.readLinesInFile(path).mkString("\n").toArray + FileContent.create(path.toString, true, lines) class CdtParser(config: Config) extends ParseProblemsLogger with PreprocessorStatementsLogger: - import CdtParser.* - - private val headerFileFinder = new HeaderFileFinder(config.inputPath) - private val parserConfig = ParserConfig.fromConfig(config) - private val definedSymbols = parserConfig.definedSymbols.asJava - private val includePaths = parserConfig.userIncludePaths - private val log = new DefaultLogService - - private var stayCpp: Boolean = false; - - private val cScannerInfo: ExtendedScannerInfo = new ExtendedScannerInfo( - definedSymbols, - (includePaths ++ parserConfig.systemIncludePathsC).map(_.toString).toArray, - parserConfig.macroFiles.map(_.toString).toArray, - parserConfig.includeFiles.map(_.toString).toArray - ) - - private val cppScannerInfo: ExtendedScannerInfo = new ExtendedScannerInfo( - definedSymbols, - (includePaths ++ parserConfig.systemIncludePathsCPP).map(_.toString).toArray, - parserConfig.macroFiles.map(_.toString).toArray, - parserConfig.includeFiles.map(_.toString).toArray - ) - - // Setup indexing - var index: Option[IIndex] = Option(EmptyCIndex.INSTANCE) - if config.useProjectIndex then - try - val allProjects: Array[ICProject] = CoreModel.getDefault.getCModel.getCProjects - index = Option(CCorePlugin.getIndexManager.getIndex(allProjects)) - catch - case e: Throwable => - - // enables parsing of code behind disabled preprocessor defines: - private var opts: Int = ILanguage.OPTION_PARSE_INACTIVE_CODE - // instructs the parser to skip function and method bodies - if !config.includeFunctionBodies then opts |= ILanguage.OPTION_SKIP_FUNCTION_BODIES - // performance optimization, allows the parser not to create image-locations - if !config.includeImageLocations then opts |= ILanguage.OPTION_NO_IMAGE_LOCATIONS - - private def createParseLanguage(file: Path): ILanguage = - if FileDefaults.isCPPFile(file.toString) then - GPPLanguage.getDefault - else - GCCLanguage.getDefault - - private def createScannerInfo(file: Path): ExtendedScannerInfo = - if stayCpp || FileDefaults.isCPPFile(file.toString) then - stayCpp = true - cppScannerInfo - else cScannerInfo - - private def parseInternal(file: Path): ParseResult = - val realPath = File(file) - if realPath.isRegularFile then // handling potentially broken symlinks - try - val fileContent = readFileAsFileContent(realPath.path) - val fileContentProvider = new CustomFileContentProvider(headerFileFinder) - val lang = createParseLanguage(realPath.path) - val scannerInfo = createScannerInfo(realPath.path) - index match - case Some(x) => if x.isFullyInitialized then x.acquireReadLock() - case _ => - val translationUnit = - lang.getASTTranslationUnit( - fileContent, - scannerInfo, - fileContentProvider, - index.get, - opts, - log - ) - val problems = CPPVisitor.getProblems(translationUnit) - if parserConfig.logProblems then logProblems(problems.toList) - if parserConfig.logPreprocessor then logPreprocessorStatements(translationUnit) - ParseResult( - Option(translationUnit), - preprocessorErrorCount = translationUnit.getPreprocessorProblemsCount, - problems = problems.length - ) - catch - case u: UnsupportedClassVersionError => - logger.debug( - "c2cpg requires at least JRE-17 to run. Please check your Java Runtime Environment!", - u - ) - System.exit(1) - ParseResult( - None, - failure = Option(u) - ) // return value to make the compiler happy - case e: Throwable => - ParseResult(None, failure = Option(e)) - finally - index match - case Some(x) => x.releaseReadLock() - case _ => - else + import CdtParser.* + + private val headerFileFinder = new HeaderFileFinder(config.inputPath) + private val parserConfig = ParserConfig.fromConfig(config) + private val definedSymbols = parserConfig.definedSymbols.asJava + private val includePaths = parserConfig.userIncludePaths + private val log = new DefaultLogService + + private var stayCpp: Boolean = false; + + private val cScannerInfo: ExtendedScannerInfo = new ExtendedScannerInfo( + definedSymbols, + (includePaths ++ parserConfig.systemIncludePathsC).map(_.toString).toArray, + parserConfig.macroFiles.map(_.toString).toArray, + parserConfig.includeFiles.map(_.toString).toArray + ) + + private val cppScannerInfo: ExtendedScannerInfo = new ExtendedScannerInfo( + definedSymbols, + (includePaths ++ parserConfig.systemIncludePathsCPP).map(_.toString).toArray, + parserConfig.macroFiles.map(_.toString).toArray, + parserConfig.includeFiles.map(_.toString).toArray + ) + + // Setup indexing + var index: Option[IIndex] = Option(EmptyCIndex.INSTANCE) + if config.useProjectIndex then + try + val allProjects: Array[ICProject] = CoreModel.getDefault.getCModel.getCProjects + index = Option(CCorePlugin.getIndexManager.getIndex(allProjects)) + catch + case e: Throwable => + + // enables parsing of code behind disabled preprocessor defines: + private var opts: Int = ILanguage.OPTION_PARSE_INACTIVE_CODE + // instructs the parser to skip function and method bodies + if !config.includeFunctionBodies then opts |= ILanguage.OPTION_SKIP_FUNCTION_BODIES + // performance optimization, allows the parser not to create image-locations + if !config.includeImageLocations then opts |= ILanguage.OPTION_NO_IMAGE_LOCATIONS + + private def createParseLanguage(file: Path): ILanguage = + if FileDefaults.isCPPFile(file.toString) then + GPPLanguage.getDefault + else + GCCLanguage.getDefault + + private def createScannerInfo(file: Path): ExtendedScannerInfo = + if stayCpp || FileDefaults.isCPPFile(file.toString) then + stayCpp = true + cppScannerInfo + else cScannerInfo + + private def parseInternal(file: Path): ParseResult = + val realPath = File(file) + if realPath.isRegularFile then // handling potentially broken symlinks + try + val fileContent = readFileAsFileContent(realPath.path) + val fileContentProvider = new CustomFileContentProvider(headerFileFinder) + val lang = createParseLanguage(realPath.path) + val scannerInfo = createScannerInfo(realPath.path) + index match + case Some(x) => if x.isFullyInitialized then x.acquireReadLock() + case _ => + val translationUnit = + lang.getASTTranslationUnit( + fileContent, + scannerInfo, + fileContentProvider, + index.get, + opts, + log + ) + val problems = CPPVisitor.getProblems(translationUnit) + if parserConfig.logProblems then logProblems(problems.toList) + if parserConfig.logPreprocessor then logPreprocessorStatements(translationUnit) + ParseResult( + Option(translationUnit), + preprocessorErrorCount = translationUnit.getPreprocessorProblemsCount, + problems = problems.length + ) + catch + case u: UnsupportedClassVersionError => + logger.debug( + "c2cpg requires at least JRE-17 to run. Please check your Java Runtime Environment!", + u + ) + System.exit(1) ParseResult( None, - failure = Option(new NoSuchFileException( - s"File '$realPath' does not exist. Check for broken symlinks!" - )) - ) - end if - end parseInternal - - def preprocessorStatements(file: Path): Iterable[IASTPreprocessorStatement] = - parse(file).map(t => preprocessorStatements(t)).getOrElse(Iterable.empty) - - def parse(file: Path): Option[IASTTranslationUnit] = - val parseResult = parseInternal(file) - parseResult match - case ParseResult(Some(t), c, p, _) => - Option(t) - case ParseResult(_, _, _, maybeThrowable) => - logger.warn( - s"Failed to parse '$file': ${maybeThrowable.map(extractParseException).getOrElse("Unknown parse error!")}" - ) - None + failure = Option(u) + ) // return value to make the compiler happy + case e: Throwable => + ParseResult(None, failure = Option(e)) + finally + index match + case Some(x) => x.releaseReadLock() + case _ => + else + ParseResult( + None, + failure = Option(new NoSuchFileException( + s"File '$realPath' does not exist. Check for broken symlinks!" + )) + ) + end if + end parseInternal + + def preprocessorStatements(file: Path): Iterable[IASTPreprocessorStatement] = + parse(file).map(t => preprocessorStatements(t)).getOrElse(Iterable.empty) + + def parse(file: Path): Option[IASTTranslationUnit] = + val parseResult = parseInternal(file) + parseResult match + case ParseResult(Some(t), c, p, _) => + Option(t) + case ParseResult(_, _, _, maybeThrowable) => + logger.warn( + s"Failed to parse '$file': ${maybeThrowable.map(extractParseException).getOrElse("Unknown parse error!")}" + ) + None end CdtParser diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/CustomFileContentProvider.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/CustomFileContentProvider.scala index db3bcb18..092ba042 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/CustomFileContentProvider.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/CustomFileContentProvider.scala @@ -13,34 +13,34 @@ import java.nio.file.Paths class CustomFileContentProvider(headerFileFinder: HeaderFileFinder) extends InternalFileContentProvider: - private val logger = LoggerFactory.getLogger(classOf[CustomFileContentProvider]) + private val logger = LoggerFactory.getLogger(classOf[CustomFileContentProvider]) - private def loadContent(path: String): InternalFileContent = - val maybeFullPath = if !getInclusionExists(path) then - headerFileFinder.find(path) - else - Option(path) - maybeFullPath - .map { foundPath => - logger.debug(s"Loading header file '$foundPath'") - CdtParser.readFileAsFileContent(Paths.get(foundPath)).asInstanceOf[ - InternalFileContent - ] - } - .getOrElse { - logger.debug(s"Cannot find header file for '$path'") - null - } + private def loadContent(path: String): InternalFileContent = + val maybeFullPath = if !getInclusionExists(path) then + headerFileFinder.find(path) + else + Option(path) + maybeFullPath + .map { foundPath => + logger.debug(s"Loading header file '$foundPath'") + CdtParser.readFileAsFileContent(Paths.get(foundPath)).asInstanceOf[ + InternalFileContent + ] + } + .getOrElse { + logger.debug(s"Cannot find header file for '$path'") + null + } - override def getContentForInclusion( - path: String, - macroDictionary: IMacroDictionary - ): InternalFileContent = - loadContent(path) + override def getContentForInclusion( + path: String, + macroDictionary: IMacroDictionary + ): InternalFileContent = + loadContent(path) - override def getContentForInclusion( - ifl: IIndexFileLocation, - astPath: String - ): InternalFileContent = - loadContent(astPath) + override def getContentForInclusion( + ifl: IIndexFileLocation, + astPath: String + ): InternalFileContent = + loadContent(astPath) end CustomFileContentProvider diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/DefaultDefines.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/DefaultDefines.scala index c484a5ea..e47d7e9e 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/DefaultDefines.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/DefaultDefines.scala @@ -1,35 +1,35 @@ package io.appthreat.c2cpg.parser object DefaultDefines: - val DEFAULT_CALL_CONVENTIONS: Map[String, String] = Map( - "__fastcall" -> "__attribute((fastcall))", - "__cdecl" -> "__attribute((cdecl))", - "__pascal" -> "__attribute((pascal))", - "__vectorcall" -> "__attribute((vectorcall))", - "__clrcall" -> "__attribute((clrcall))", - "__stdcall" -> "__attribute((stdcall))", - "__thiscall" -> "__attribute((thiscall))", - "__declspec" -> "__attribute((declspec))", - "__restrict" -> "__attribute((restrict))", - "__sptr" -> "__attribute((sptr))", - "__uptr" -> "__attribute((uptr))", - "__syscall" -> "__attribute((syscall))", - "__oldcall" -> "__attribute((oldcall))", - "__unaligned" -> "__attribute((unaligned))", - "__w64" -> "__attribute((w64))", - "__asm" -> "__attribute((asm))", - "__based" -> "__attribute((based))", - "__interface" -> "__attribute((interface))", - "__event" -> "__attribute((event))", - "__hook" -> "__attribute((hook))", - "__unhook" -> "__attribute((unhook))", - "__raise" -> "__attribute((raise))", - "__try" -> "__attribute((try))", - "__except" -> "__attribute((except))", - "__finally" -> "__attribute((finally))", - "__m128" -> "__attribute((m128))", - "__m128d" -> "__attribute((m128d))", - "__m128i" -> "__attribute((m128i))", - "__m64" -> "__attribute((m64))" - ) + val DEFAULT_CALL_CONVENTIONS: Map[String, String] = Map( + "__fastcall" -> "__attribute((fastcall))", + "__cdecl" -> "__attribute((cdecl))", + "__pascal" -> "__attribute((pascal))", + "__vectorcall" -> "__attribute((vectorcall))", + "__clrcall" -> "__attribute((clrcall))", + "__stdcall" -> "__attribute((stdcall))", + "__thiscall" -> "__attribute((thiscall))", + "__declspec" -> "__attribute((declspec))", + "__restrict" -> "__attribute((restrict))", + "__sptr" -> "__attribute((sptr))", + "__uptr" -> "__attribute((uptr))", + "__syscall" -> "__attribute((syscall))", + "__oldcall" -> "__attribute((oldcall))", + "__unaligned" -> "__attribute((unaligned))", + "__w64" -> "__attribute((w64))", + "__asm" -> "__attribute((asm))", + "__based" -> "__attribute((based))", + "__interface" -> "__attribute((interface))", + "__event" -> "__attribute((event))", + "__hook" -> "__attribute((hook))", + "__unhook" -> "__attribute((unhook))", + "__raise" -> "__attribute((raise))", + "__try" -> "__attribute((try))", + "__except" -> "__attribute((except))", + "__finally" -> "__attribute((finally))", + "__m128" -> "__attribute((m128))", + "__m128d" -> "__attribute((m128d))", + "__m128i" -> "__attribute((m128i))", + "__m64" -> "__attribute((m64))" + ) end DefaultDefines diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/FileDefaults.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/FileDefaults.scala index bb2766a8..874e05a3 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/FileDefaults.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/FileDefaults.scala @@ -2,24 +2,25 @@ package io.appthreat.c2cpg.parser object FileDefaults: - val C_EXT: String = ".c" - val CPP_EXT: String = ".cpp" + val C_EXT: String = ".c" + val CPP_EXT: String = ".cpp" - private val CC_EXT = ".cc" - private val C_HEADER_EXT = ".h" - private val CPP_HEADER_EXT = ".hpp" - private val OTHER_HEADER_EXT = ".hh" + private val CC_EXT = ".cc" + private val C_HEADER_EXT = ".h" + private val CPP_HEADER_EXT = ".hpp" + private val OTHER_HEADER_EXT = ".hh" + private val PREPROCESSED_EXT = ".i" - val SOURCE_FILE_EXTENSIONS: Set[String] = Set(C_EXT, CC_EXT, CPP_EXT) + val SOURCE_FILE_EXTENSIONS: Set[String] = Set(C_EXT, CC_EXT, CPP_EXT) - val HEADER_FILE_EXTENSIONS: Set[String] = - Set(C_HEADER_EXT, CPP_HEADER_EXT, OTHER_HEADER_EXT, ".h.in", ".tmh") + val HEADER_FILE_EXTENSIONS: Set[String] = + Set(C_HEADER_EXT, CPP_HEADER_EXT, OTHER_HEADER_EXT, PREPROCESSED_EXT, ".h.in", ".tmh") - private val CPP_FILE_EXTENSIONS = Set(CC_EXT, CPP_EXT, CPP_HEADER_EXT, ".ccm", ".cxxm", ".c++m") + private val CPP_FILE_EXTENSIONS = Set(CC_EXT, CPP_EXT, CPP_HEADER_EXT, ".ccm", ".cxxm", ".c++m") - def isHeaderFile(filePath: String): Boolean = - HEADER_FILE_EXTENSIONS.exists(filePath.endsWith) + def isHeaderFile(filePath: String): Boolean = + HEADER_FILE_EXTENSIONS.exists(filePath.endsWith) - def isCPPFile(filePath: String): Boolean = - CPP_FILE_EXTENSIONS.exists(filePath.endsWith) + def isCPPFile(filePath: String): Boolean = + CPP_FILE_EXTENSIONS.exists(filePath.endsWith) end FileDefaults diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/HeaderFileFinder.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/HeaderFileFinder.scala index 5572d3aa..fc5cce99 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/HeaderFileFinder.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/HeaderFileFinder.scala @@ -8,19 +8,19 @@ import java.nio.file.Path class HeaderFileFinder(root: String): - private val nameToPathMap: Map[String, List[Path]] = SourceFiles - .determine(root, FileDefaults.HEADER_FILE_EXTENSIONS) - .map { p => - val file = File(p) - (file.name, file.path) - } - .groupBy(_._1) - .map(x => (x._1, x._2.map(_._2))) + private val nameToPathMap: Map[String, List[Path]] = SourceFiles + .determine(root, FileDefaults.HEADER_FILE_EXTENSIONS) + .map { p => + val file = File(p) + (file.name, file.path) + } + .groupBy(_._1) + .map(x => (x._1, x._2.map(_._2))) - /** Given an unresolved header file, given as a non-existing absolute path, determine whether a - * header file with the same name can be found anywhere in the code base. - */ - def find(path: String): Option[String] = File(path).nameOption.flatMap { name => - val matches = nameToPathMap.getOrElse(name, List()) - matches.map(_.toString).sortBy(x => Levenshtein.distance(x, path)).headOption - } + /** Given an unresolved header file, given as a non-existing absolute path, determine whether a + * header file with the same name can be found anywhere in the code base. + */ + def find(path: String): Option[String] = File(path).nameOption.flatMap { name => + val matches = nameToPathMap.getOrElse(name, List()) + matches.map(_.toString).sortBy(x => Levenshtein.distance(x, path)).headOption + } diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/ParseProblemsLogger.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/ParseProblemsLogger.scala index b8b07d42..4c989d5a 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/ParseProblemsLogger.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/ParseProblemsLogger.scala @@ -5,33 +5,33 @@ import org.slf4j.LoggerFactory trait ParseProblemsLogger: - this: CdtParser => + this: CdtParser => - private val logger = LoggerFactory.getLogger(classOf[ParseProblemsLogger]) + private val logger = LoggerFactory.getLogger(classOf[ParseProblemsLogger]) - private def logProblemNode(node: IASTProblem): Unit = - val text = s"""Parse problem '${node.getClass.getSimpleName}' occurred! + private def logProblemNode(node: IASTProblem): Unit = + val text = s"""Parse problem '${node.getClass.getSimpleName}' occurred! | Code: '${node.getRawSignature}' | File: '${node.getFileLocation.getFileName}' | Line: ${node.getFileLocation.getStartingLineNumber} | """.stripMargin - logger.debug(text) + logger.debug(text) - protected def logProblems(problems: List[IASTProblem]): Unit = - problems.foreach(logProblemNode) + protected def logProblems(problems: List[IASTProblem]): Unit = + problems.foreach(logProblemNode) - /** The exception message might be null for parse failures due to ambiguous nodes that can't be - * resolved successfully. We extract better log messages for this case here. - */ - protected def extractParseException(exception: Throwable): String = - Option(exception.getMessage) match - case Some(message) => message - case None => - exception.getStackTrace - .collectFirst { - case stackTraceElement: StackTraceElement - if stackTraceElement.getClassName.endsWith("ASTAmbiguousNode") => - "Could not resolve ambiguous node!" - } - .getOrElse(exception.getStackTrace.mkString(System.lineSeparator())) + /** The exception message might be null for parse failures due to ambiguous nodes that can't be + * resolved successfully. We extract better log messages for this case here. + */ + protected def extractParseException(exception: Throwable): String = + Option(exception.getMessage) match + case Some(message) => message + case None => + exception.getStackTrace + .collectFirst { + case stackTraceElement: StackTraceElement + if stackTraceElement.getClassName.endsWith("ASTAmbiguousNode") => + "Could not resolve ambiguous node!" + } + .getOrElse(exception.getStackTrace.mkString(System.lineSeparator())) end ParseProblemsLogger diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/ParserConfig.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/ParserConfig.scala index 732a38bf..05f6f56d 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/ParserConfig.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/ParserConfig.scala @@ -7,33 +7,33 @@ import java.nio.file.{Path, Paths} object ParserConfig: - def empty: ParserConfig = - ParserConfig( - Set.empty, - Set.empty, - Set.empty, - Set.empty, - Map.empty, - Set.empty, - logProblems = false, - logPreprocessor = false - ) + def empty: ParserConfig = + ParserConfig( + Set.empty, + Set.empty, + Set.empty, + Set.empty, + Map.empty, + Set.empty, + logProblems = false, + logPreprocessor = false + ) - def fromConfig(config: Config): ParserConfig = ParserConfig( - config.includeFiles.map(Paths.get(_).toAbsolutePath), - config.includePaths.map(Paths.get(_).toAbsolutePath), - IncludeAutoDiscovery.discoverIncludePathsC(config), - IncludeAutoDiscovery.discoverIncludePathsCPP(config), - config.defines.map { - case define if define.contains("=") => - val s = define.split("=") - s.head -> s(1) - case define => define -> "true" - }.toMap ++ DefaultDefines.DEFAULT_CALL_CONVENTIONS, - config.macroFiles.map(Paths.get(_).toAbsolutePath), - config.logProblems, - config.logPreprocessor - ) + def fromConfig(config: Config): ParserConfig = ParserConfig( + config.includeFiles.map(Paths.get(_).toAbsolutePath), + config.includePaths.map(Paths.get(_).toAbsolutePath), + IncludeAutoDiscovery.discoverIncludePathsC(config), + IncludeAutoDiscovery.discoverIncludePathsCPP(config), + config.defines.map { + case define if define.contains("=") => + val s = define.split("=") + s.head -> s(1) + case define => define -> "true" + }.toMap ++ DefaultDefines.DEFAULT_CALL_CONVENTIONS, + config.macroFiles.map(Paths.get(_).toAbsolutePath), + config.logProblems, + config.logPreprocessor + ) end ParserConfig case class ParserConfig( diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/PreprocessorStatementsLogger.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/PreprocessorStatementsLogger.scala index 2acc1bc6..1d29ed42 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/PreprocessorStatementsLogger.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/parser/PreprocessorStatementsLogger.scala @@ -11,33 +11,33 @@ import org.slf4j.LoggerFactory trait PreprocessorStatementsLogger: - this: CdtParser => + this: CdtParser => - private val logger = LoggerFactory.getLogger(classOf[PreprocessorStatementsLogger]) + private val logger = LoggerFactory.getLogger(classOf[PreprocessorStatementsLogger]) - private def logPreprocessorStatement(node: IASTPreprocessorStatement): Unit = - val text = s"""Preprocessor statement '${node.getClass.getSimpleName}' found! + private def logPreprocessorStatement(node: IASTPreprocessorStatement): Unit = + val text = s"""Preprocessor statement '${node.getClass.getSimpleName}' found! | Code: '${node.getRawSignature}' | File: '${node.getFileLocation.getFileName}' | Line: ${node.getFileLocation.getStartingLineNumber}""".stripMargin - val additionalInfo = node match - case s: IASTPreprocessorFunctionStyleMacroDefinition => - s""" + val additionalInfo = node match + case s: IASTPreprocessorFunctionStyleMacroDefinition => + s""" | Parameter: ${s.getParameters.map(_.getRawSignature).mkString(", ")} | Expansion: ${s.getExpansion}""".stripMargin - case s: IASTPreprocessorIfStatement => - s""" + case s: IASTPreprocessorIfStatement => + s""" | Defined: ${s.taken()}""".stripMargin - case s: IASTPreprocessorIfdefStatement => - s""" + case s: IASTPreprocessorIfdefStatement => + s""" | Defined: ${s.taken()}""".stripMargin - case _ => "" - logger.debug(s"$text$additionalInfo") + case _ => "" + logger.debug(s"$text$additionalInfo") - protected def preprocessorStatements(translationUnit: IASTTranslationUnit) - : Iterable[IASTPreprocessorStatement] = - translationUnit.getAllPreprocessorStatements + protected def preprocessorStatements(translationUnit: IASTTranslationUnit) + : Iterable[IASTPreprocessorStatement] = + translationUnit.getAllPreprocessorStatements - protected def logPreprocessorStatements(translationUnit: IASTTranslationUnit): Unit = - preprocessorStatements(translationUnit).foreach(logPreprocessorStatement) + protected def logPreprocessorStatements(translationUnit: IASTTranslationUnit): Unit = + preprocessorStatements(translationUnit).foreach(logPreprocessorStatement) end PreprocessorStatementsLogger diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/AstCreationPass.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/AstCreationPass.scala index 9ad26848..d3bc1e28 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/AstCreationPass.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/AstCreationPass.scala @@ -16,39 +16,39 @@ import scala.util.matching.Regex class AstCreationPass(cpg: Cpg, config: Config, report: Report = new Report()) extends ConcurrentWriterCpgPass[String](cpg): - private val file2OffsetTable: ConcurrentHashMap[String, Array[Int]] = new ConcurrentHashMap() - private val parser: CdtParser = new CdtParser(config) + private val file2OffsetTable: ConcurrentHashMap[String, Array[Int]] = new ConcurrentHashMap() + private val parser: CdtParser = new CdtParser(config) - private val EscapedFileSeparator = Pattern.quote(java.io.File.separator) - private val DefaultIgnoredFolders: List[Regex] = List( - "\\..*".r, - s"(.*[$EscapedFileSeparator])?tests?[$EscapedFileSeparator].*".r, - s"(.*[$EscapedFileSeparator])?CMakeFiles[$EscapedFileSeparator].*".r - ) + private val EscapedFileSeparator = Pattern.quote(java.io.File.separator) + private val DefaultIgnoredFolders: List[Regex] = List( + "\\..*".r, + s"(.*[$EscapedFileSeparator])?tests?[$EscapedFileSeparator].*".r, + s"(.*[$EscapedFileSeparator])?CMakeFiles[$EscapedFileSeparator].*".r + ) - override def generateParts(): Array[String] = - SourceFiles - .determine( - config.inputPath, - FileDefaults.SOURCE_FILE_EXTENSIONS ++ FileDefaults.HEADER_FILE_EXTENSIONS, - config.withDefaultIgnoredFilesRegex(DefaultIgnoredFolders) - ) - .sortWith(_.compareToIgnoreCase(_) > 0) - .toArray + override def generateParts(): Array[String] = + SourceFiles + .determine( + config.inputPath, + FileDefaults.SOURCE_FILE_EXTENSIONS ++ FileDefaults.HEADER_FILE_EXTENSIONS, + config.withDefaultIgnoredFilesRegex(DefaultIgnoredFolders) + ) + .sortWith(_.compareToIgnoreCase(_) > 0) + .toArray - override def runOnPart(diffGraph: DiffGraphBuilder, filename: String): Unit = - val path = Paths.get(filename).toAbsolutePath - val relPath = SourceFiles.toRelativePath(path.toString, config.inputPath) - try - val parseResult = parser.parse(path) - parseResult match - case Some(translationUnit) => - val localDiff = - new AstCreator(relPath, config, translationUnit, file2OffsetTable)( - config.schemaValidation - ).createAst() - diffGraph.absorb(localDiff) - case None => - catch - case e: Throwable => + override def runOnPart(diffGraph: DiffGraphBuilder, filename: String): Unit = + val path = Paths.get(filename).toAbsolutePath + val relPath = SourceFiles.toRelativePath(path.toString, config.inputPath) + try + val parseResult = parser.parse(path) + parseResult match + case Some(translationUnit) => + val localDiff = + new AstCreator(relPath, config, translationUnit, file2OffsetTable)( + config.schemaValidation + ).createAst() + diffGraph.absorb(localDiff) + case None => + catch + case e: Throwable => end AstCreationPass diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/ConfigFileCreationPass.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/ConfigFileCreationPass.scala index f8bfdef3..b6e7d9d6 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/ConfigFileCreationPass.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/ConfigFileCreationPass.scala @@ -6,22 +6,22 @@ import io.shiftleft.codepropertygraph.Cpg class ConfigFileCreationPass(cpg: Cpg) extends XConfigFileCreationPass(cpg): - override val configFileFilters: List[File => Boolean] = List( - // TOML files - extensionFilter(".toml"), - // INI files - extensionFilter(".ini"), - // YAML files - extensionFilter(".yaml"), - extensionFilter(".lock"), - pathEndFilter("conanfile.txt"), - extensionFilter(".cmake"), - extensionFilter(".build"), - pathEndFilter("CMakeLists.txt"), - pathEndFilter("bom.json"), - pathEndFilter(".cdx.json"), - pathEndFilter("chennai.json"), - pathEndFilter("setup.cfg"), - pathEndFilter("setup.py") - ) + override val configFileFilters: List[File => Boolean] = List( + // TOML files + extensionFilter(".toml"), + // INI files + extensionFilter(".ini"), + // YAML files + extensionFilter(".yaml"), + extensionFilter(".lock"), + pathEndFilter("conanfile.txt"), + extensionFilter(".cmake"), + extensionFilter(".build"), + pathEndFilter("CMakeLists.txt"), + pathEndFilter("bom.json"), + pathEndFilter(".cdx.json"), + pathEndFilter("chennai.json"), + pathEndFilter("setup.cfg"), + pathEndFilter("setup.py") + ) end ConfigFileCreationPass diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/PreprocessorPass.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/PreprocessorPass.scala index e66ef80b..64fd76e7 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/PreprocessorPass.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/PreprocessorPass.scala @@ -15,23 +15,23 @@ import scala.collection.parallel.immutable.ParIterable class PreprocessorPass(config: Config): - private val parser = new CdtParser(config) + private val parser = new CdtParser(config) - def run(): ParIterable[String] = - SourceFiles.determine(config.inputPath, FileDefaults.SOURCE_FILE_EXTENSIONS).par.flatMap( - runOnPart - ) + def run(): ParIterable[String] = + SourceFiles.determine(config.inputPath, FileDefaults.SOURCE_FILE_EXTENSIONS).par.flatMap( + runOnPart + ) - private def preprocessorStatement2String(stmt: IASTPreprocessorStatement): Option[String] = - stmt match - case s: IASTPreprocessorIfStatement => - Option(s"${s.getCondition.mkString}${if s.taken() then "=true" else ""}") - case s: IASTPreprocessorIfdefStatement => - Option(s"${s.getCondition.mkString}${if s.taken() then "=true" else ""}") - case _ => None + private def preprocessorStatement2String(stmt: IASTPreprocessorStatement): Option[String] = + stmt match + case s: IASTPreprocessorIfStatement => + Option(s"${s.getCondition.mkString}${if s.taken() then "=true" else ""}") + case s: IASTPreprocessorIfdefStatement => + Option(s"${s.getCondition.mkString}${if s.taken() then "=true" else ""}") + case _ => None - private def runOnPart(filename: String): Iterable[String] = - parser.preprocessorStatements(Paths.get(filename)).flatMap( - preprocessorStatement2String - ).toSet + private def runOnPart(filename: String): Iterable[String] = + parser.preprocessorStatements(Paths.get(filename)).flatMap( + preprocessorStatement2String + ).toSet end PreprocessorPass diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/TypeDeclNodePass.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/TypeDeclNodePass.scala index 6a50d8f7..404570d1 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/TypeDeclNodePass.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/passes/TypeDeclNodePass.scala @@ -14,53 +14,53 @@ import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal class TypeDeclNodePass(cpg: Cpg)(implicit withSchemaValidation: ValidationMode) extends CpgPass(cpg): - private val filename: String = "" - private val globalName: String = NamespaceTraversal.globalNamespaceName - private val fullName: String = MetaDataPass.getGlobalNamespaceBlockFullName(Option(filename)) + private val filename: String = "" + private val globalName: String = NamespaceTraversal.globalNamespaceName + private val fullName: String = MetaDataPass.getGlobalNamespaceBlockFullName(Option(filename)) - private val typeDeclFullNames: Set[String] = cpg.typeDecl.fullName.toSetImmutable + private val typeDeclFullNames: Set[String] = cpg.typeDecl.fullName.toSetImmutable - private def createGlobalAst(): Ast = - val includesFile = NewFile().name(filename) - val namespaceBlock = NewNamespaceBlock() + private def createGlobalAst(): Ast = + val includesFile = NewFile().name(filename) + val namespaceBlock = NewNamespaceBlock() + .name(globalName) + .fullName(fullName) + .filename(filename) + val fakeGlobalIncludesMethod = + NewMethod() .name(globalName) + .code(globalName) .fullName(fullName) .filename(filename) - val fakeGlobalIncludesMethod = - NewMethod() - .name(globalName) - .code(globalName) - .fullName(fullName) - .filename(filename) - .lineNumber(1) - .astParentType(NodeTypes.NAMESPACE_BLOCK) - .astParentFullName(fullName) - val blockNode = NewBlock().typeFullName(Defines.anyTypeName) - val methodReturn = newMethodReturnNode(Defines.anyTypeName, line = None, column = None) - Ast(includesFile).withChild( - Ast(namespaceBlock) - .withChild( - Ast(fakeGlobalIncludesMethod).withChild(Ast(blockNode)).withChild(Ast(methodReturn)) - ) - ) - end createGlobalAst + .lineNumber(1) + .astParentType(NodeTypes.NAMESPACE_BLOCK) + .astParentFullName(fullName) + val blockNode = NewBlock().typeFullName(Defines.anyTypeName) + val methodReturn = newMethodReturnNode(Defines.anyTypeName, line = None, column = None) + Ast(includesFile).withChild( + Ast(namespaceBlock) + .withChild( + Ast(fakeGlobalIncludesMethod).withChild(Ast(blockNode)).withChild(Ast(methodReturn)) + ) + ) + end createGlobalAst - private def typeNeedsTypeDeclStub(t: Type): Boolean = - !typeDeclFullNames.contains(t.typeDeclFullName) + private def typeNeedsTypeDeclStub(t: Type): Boolean = + !typeDeclFullNames.contains(t.typeDeclFullName) - override def run(dstGraph: DiffGraphBuilder): Unit = - var hadMissingTypeDecl = false - cpg.typ.filter(typeNeedsTypeDeclStub).foreach { t => - val newTypeDecl = NewTypeDecl() - .name(t.name) - .fullName(t.typeDeclFullName) - .code(t.name) - .isExternal(true) - .filename(filename) - .astParentType(NodeTypes.NAMESPACE_BLOCK) - .astParentFullName(fullName) - dstGraph.addNode(newTypeDecl) - hadMissingTypeDecl = true - } - if hadMissingTypeDecl then Ast.storeInDiffGraph(createGlobalAst(), dstGraph) + override def run(dstGraph: DiffGraphBuilder): Unit = + var hadMissingTypeDecl = false + cpg.typ.filter(typeNeedsTypeDeclStub).foreach { t => + val newTypeDecl = NewTypeDecl() + .name(t.name) + .fullName(t.typeDeclFullName) + .code(t.name) + .isExternal(true) + .filename(filename) + .astParentType(NodeTypes.NAMESPACE_BLOCK) + .astParentFullName(fullName) + dstGraph.addNode(newTypeDecl) + hadMissingTypeDecl = true + } + if hadMissingTypeDecl then Ast.storeInDiffGraph(createGlobalAst(), dstGraph) end TypeDeclNodePass diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/utils/ExternalCommand.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/utils/ExternalCommand.scala index 4fffc5fa..731db5e6 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/utils/ExternalCommand.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/utils/ExternalCommand.scala @@ -6,23 +6,23 @@ import scala.util.{Failure, Success, Try} object ExternalCommand: - private val IS_WIN: Boolean = - scala.util.Properties.isWin + private val IS_WIN: Boolean = + scala.util.Properties.isWin - private val shellPrefix: Seq[String] = - if IS_WIN then "cmd" :: "/c" :: Nil else "sh" :: "-c" :: Nil + private val shellPrefix: Seq[String] = + if IS_WIN then "cmd" :: "/c" :: Nil else "sh" :: "-c" :: Nil - def run(command: String): Try[Seq[String]] = - val result = mutable.ArrayBuffer.empty[String] - val lineHandler: String => Unit = result.addOne - Process(shellPrefix :+ command).!(ProcessLogger(lineHandler, lineHandler)) match - case 0 => - Success(result.toSeq) - case 1 - if IS_WIN && - command != IncludeAutoDiscovery.GCC_VERSION_COMMAND && - IncludeAutoDiscovery.gccAvailable() => - Success(result.toSeq) - case _ => - Failure(new RuntimeException(result.mkString(System.lineSeparator()))) + def run(command: String): Try[Seq[String]] = + val result = mutable.ArrayBuffer.empty[String] + val lineHandler: String => Unit = result.addOne + Process(shellPrefix :+ command).!(ProcessLogger(lineHandler, lineHandler)) match + case 0 => + Success(result.toSeq) + case 1 + if IS_WIN && + command != IncludeAutoDiscovery.GCC_VERSION_COMMAND && + IncludeAutoDiscovery.gccAvailable() => + Success(result.toSeq) + case _ => + Failure(new RuntimeException(result.mkString(System.lineSeparator()))) end ExternalCommand diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/utils/IncludeAutoDiscovery.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/utils/IncludeAutoDiscovery.scala index 4856d6c0..e921cdf7 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/utils/IncludeAutoDiscovery.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/utils/IncludeAutoDiscovery.scala @@ -9,90 +9,90 @@ import scala.util.Success object IncludeAutoDiscovery: - private val logger = LoggerFactory.getLogger(IncludeAutoDiscovery.getClass) - - private val IS_WIN = scala.util.Properties.isWin - - val GCC_VERSION_COMMAND = "gcc --version" - - private val CPP_INCLUDE_COMMAND = - if IS_WIN then "gcc -xc++ -E -v . -o nul" else "gcc -xc++ -E -v /dev/null -o /dev/null" - - private val C_INCLUDE_COMMAND = - if IS_WIN then "gcc -xc -E -v . -o nul" else "gcc -xc -E -v /dev/null -o /dev/null" - - // Only check once - private var isGccAvailable: Option[Boolean] = None - - // Only discover them once - private var systemIncludePathsC: Set[Path] = Set.empty - private var systemIncludePathsCPP: Set[Path] = Set.empty - - private def checkForGcc(): Boolean = - logger.debug("Checking gcc ...") - ExternalCommand.run(GCC_VERSION_COMMAND) match - case Success(result) => - logger.debug(s"GCC is available: ${result.mkString(System.lineSeparator())}") - true - case _ => - logger.warn( - "GCC is not installed. Discovery of system include paths will not be available." - ) - false - - def gccAvailable(): Boolean = isGccAvailable match - case Some(value) => - value - case None => - isGccAvailable = Option(checkForGcc()) - isGccAvailable.get - - private def extractPaths(output: Seq[String]): Set[Path] = - val startIndex = - output.indexWhere(_.contains("#include")) + 2 - val endIndex = - if IS_WIN then output.indexWhere(_.startsWith("End of search list.")) - 1 - else output.indexWhere(_.startsWith("COMPILER_PATH")) - 1 - output.slice(startIndex, endIndex).map(p => Paths.get(p.trim).toRealPath()).toSet - - private def discoverPaths(command: String): Set[Path] = ExternalCommand.run(command) match - case Success(output) => extractPaths(output) - case Failure(exception) => - logger.warn( - s"Unable to discover system include paths. Running '$command' failed.", - exception - ) - Set.empty - - def discoverIncludePathsC(config: Config): Set[Path] = - if config.includePathsAutoDiscovery && systemIncludePathsC.nonEmpty then - systemIncludePathsC - else if config.includePathsAutoDiscovery && systemIncludePathsC.isEmpty && gccAvailable() - then - val includePathsC = discoverPaths(C_INCLUDE_COMMAND) - if includePathsC.nonEmpty then - logger.debug( - s"Using the following C system include paths:${includePathsC - .mkString(s"${System.lineSeparator()}- ", s"${System.lineSeparator()}- ", System.lineSeparator())}" - ) - systemIncludePathsC = includePathsC - includePathsC - else - Set.empty - - def discoverIncludePathsCPP(config: Config): Set[Path] = - if config.includePathsAutoDiscovery && systemIncludePathsCPP.nonEmpty then - systemIncludePathsCPP - else if config.includePathsAutoDiscovery && systemIncludePathsCPP.isEmpty && gccAvailable() - then - val includePathsCPP = discoverPaths(CPP_INCLUDE_COMMAND) - if includePathsCPP.nonEmpty then - logger.debug( - s"Using the following CPP system include paths:${includePathsCPP - .mkString(s"${System.lineSeparator()}- ", s"${System.lineSeparator()}- ", System.lineSeparator())}" - ) - systemIncludePathsCPP = includePathsCPP - includePathsCPP - else - Set.empty + private val logger = LoggerFactory.getLogger(IncludeAutoDiscovery.getClass) + + private val IS_WIN = scala.util.Properties.isWin + + val GCC_VERSION_COMMAND = "gcc --version" + + private val CPP_INCLUDE_COMMAND = + if IS_WIN then "gcc -xc++ -E -v . -o nul" else "gcc -xc++ -E -v /dev/null -o /dev/null" + + private val C_INCLUDE_COMMAND = + if IS_WIN then "gcc -xc -E -v . -o nul" else "gcc -xc -E -v /dev/null -o /dev/null" + + // Only check once + private var isGccAvailable: Option[Boolean] = None + + // Only discover them once + private var systemIncludePathsC: Set[Path] = Set.empty + private var systemIncludePathsCPP: Set[Path] = Set.empty + + private def checkForGcc(): Boolean = + logger.debug("Checking gcc ...") + ExternalCommand.run(GCC_VERSION_COMMAND) match + case Success(result) => + logger.debug(s"GCC is available: ${result.mkString(System.lineSeparator())}") + true + case _ => + logger.warn( + "GCC is not installed. Discovery of system include paths will not be available." + ) + false + + def gccAvailable(): Boolean = isGccAvailable match + case Some(value) => + value + case None => + isGccAvailable = Option(checkForGcc()) + isGccAvailable.get + + private def extractPaths(output: Seq[String]): Set[Path] = + val startIndex = + output.indexWhere(_.contains("#include")) + 2 + val endIndex = + if IS_WIN then output.indexWhere(_.startsWith("End of search list.")) - 1 + else output.indexWhere(_.startsWith("COMPILER_PATH")) - 1 + output.slice(startIndex, endIndex).map(p => Paths.get(p.trim).toRealPath()).toSet + + private def discoverPaths(command: String): Set[Path] = ExternalCommand.run(command) match + case Success(output) => extractPaths(output) + case Failure(exception) => + logger.warn( + s"Unable to discover system include paths. Running '$command' failed.", + exception + ) + Set.empty + + def discoverIncludePathsC(config: Config): Set[Path] = + if config.includePathsAutoDiscovery && systemIncludePathsC.nonEmpty then + systemIncludePathsC + else if config.includePathsAutoDiscovery && systemIncludePathsC.isEmpty && gccAvailable() + then + val includePathsC = discoverPaths(C_INCLUDE_COMMAND) + if includePathsC.nonEmpty then + logger.debug( + s"Using the following C system include paths:${includePathsC + .mkString(s"${System.lineSeparator()}- ", s"${System.lineSeparator()}- ", System.lineSeparator())}" + ) + systemIncludePathsC = includePathsC + includePathsC + else + Set.empty + + def discoverIncludePathsCPP(config: Config): Set[Path] = + if config.includePathsAutoDiscovery && systemIncludePathsCPP.nonEmpty then + systemIncludePathsCPP + else if config.includePathsAutoDiscovery && systemIncludePathsCPP.isEmpty && gccAvailable() + then + val includePathsCPP = discoverPaths(CPP_INCLUDE_COMMAND) + if includePathsCPP.nonEmpty then + logger.debug( + s"Using the following CPP system include paths:${includePathsCPP + .mkString(s"${System.lineSeparator()}- ", s"${System.lineSeparator()}- ", System.lineSeparator())}" + ) + systemIncludePathsCPP = includePathsCPP + includePathsCPP + else + Set.empty end IncludeAutoDiscovery diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/utils/Report.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/utils/Report.scala index d628c6cc..c522ec1f 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/utils/Report.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/utils/Report.scala @@ -6,80 +6,80 @@ import scala.collection.concurrent.TrieMap object Report: - private val logger = LoggerFactory.getLogger(Report.getClass) + private val logger = LoggerFactory.getLogger(Report.getClass) - private type FileName = String + private type FileName = String - private type Reports = TrieMap[FileName, ReportEntry] + private type Reports = TrieMap[FileName, ReportEntry] - private case class ReportEntry(loc: Int, parsed: Boolean, cpgGen: Boolean, duration: Long): - def toSeq: Seq[String] = - val lines = loc.toString - val dur = if duration == 0 then "-" else TimeUtils.pretty(duration) - val wasParsed = if parsed then "yes" else "no" - val gotCpg = if cpgGen then "yes" else "no" - Seq(lines, wasParsed, gotCpg, dur) + private case class ReportEntry(loc: Int, parsed: Boolean, cpgGen: Boolean, duration: Long): + def toSeq: Seq[String] = + val lines = loc.toString + val dur = if duration == 0 then "-" else TimeUtils.pretty(duration) + val wasParsed = if parsed then "yes" else "no" + val gotCpg = if cpgGen then "yes" else "no" + Seq(lines, wasParsed, gotCpg, dur) class Report: - import Report.* + import Report.* - private val reports: Reports = TrieMap.empty + private val reports: Reports = TrieMap.empty - private def formatTable(table: Seq[Seq[String]]): String = - if table.isEmpty then "" - else - // Get column widths based on the maximum cell width in each column (+2 for a one character padding on each side) - val colWidths = - table.transpose.map(_.map(cell => if cell == null then 0 else cell.length).max + 2) - // Format each row - val rows = table.map( - _.zip(colWidths) - .map { case (item, size) => s" %-${size - 1}s".format(item) } - .mkString("|", "|", "|") - ) - // Formatted separator row, used to separate the header and draw table borders - val separator = colWidths.map("-" * _).mkString("+", "+", "+") - // Put the table together and return - val header = rows.head - val content = rows.tail.take(rows.tail.size - 1) - val footer = rows.tail.last - (separator +: header +: separator +: content :+ separator :+ footer :+ separator) - .mkString("\n") - - def print(): Unit = - val rows = reports.toSeq - .sortBy(_._1) - .zipWithIndex - .view - .map { case ((file, sum), index) => - s"${index + 1}" +: file +: sum.toSeq - } - .toSeq - val numOfReports = reports.size - val header = Seq(Seq("#", "File", "LOC", "Parsed", "Got a CPG", "Duration")) - val footer = Seq( - Seq( - "Total", - "", - s"${reports.map(_._2.loc).sum}", - s"${reports.count(_._2.parsed)}/$numOfReports", - s"${reports.count(_._2.cpgGen)}/$numOfReports", - "" - ) + private def formatTable(table: Seq[Seq[String]]): String = + if table.isEmpty then "" + else + // Get column widths based on the maximum cell width in each column (+2 for a one character padding on each side) + val colWidths = + table.transpose.map(_.map(cell => if cell == null then 0 else cell.length).max + 2) + // Format each row + val rows = table.map( + _.zip(colWidths) + .map { case (item, size) => s" %-${size - 1}s".format(item) } + .mkString("|", "|", "|") ) - val table = header ++ rows ++ footer - logger.debug(s"Report:${System.lineSeparator()}${formatTable(table)}") - end print + // Formatted separator row, used to separate the header and draw table borders + val separator = colWidths.map("-" * _).mkString("+", "+", "+") + // Put the table together and return + val header = rows.head + val content = rows.tail.take(rows.tail.size - 1) + val footer = rows.tail.last + (separator +: header +: separator +: content :+ separator :+ footer :+ separator) + .mkString("\n") + + def print(): Unit = + val rows = reports.toSeq + .sortBy(_._1) + .zipWithIndex + .view + .map { case ((file, sum), index) => + s"${index + 1}" +: file +: sum.toSeq + } + .toSeq + val numOfReports = reports.size + val header = Seq(Seq("#", "File", "LOC", "Parsed", "Got a CPG", "Duration")) + val footer = Seq( + Seq( + "Total", + "", + s"${reports.map(_._2.loc).sum}", + s"${reports.count(_._2.parsed)}/$numOfReports", + s"${reports.count(_._2.cpgGen)}/$numOfReports", + "" + ) + ) + val table = header ++ rows ++ footer + logger.debug(s"Report:${System.lineSeparator()}${formatTable(table)}") + end print - def addReportInfo( - fileName: FileName, - loc: Int, - parsed: Boolean = false, - cpgGen: Boolean = false, - duration: Long = 0 - ): Unit = reports(fileName) = ReportEntry(loc, parsed, cpgGen, duration) + def addReportInfo( + fileName: FileName, + loc: Int, + parsed: Boolean = false, + cpgGen: Boolean = false, + duration: Long = 0 + ): Unit = reports(fileName) = ReportEntry(loc, parsed, cpgGen, duration) - def updateReport(fileName: FileName, cpg: Boolean, duration: Long): Unit = - reports.updateWith(fileName)(_.map(_.copy(cpgGen = cpg, duration = duration))) + def updateReport(fileName: FileName, cpg: Boolean, duration: Long): Unit = + reports.updateWith(fileName)(_.map(_.copy(cpgGen = cpg, duration = duration))) end Report diff --git a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/utils/TimeUtils.scala b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/utils/TimeUtils.scala index 86422f26..cb8d6bda 100644 --- a/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/utils/TimeUtils.scala +++ b/platform/frontends/c2cpg/src/main/scala/io/appthreat/c2cpg/utils/TimeUtils.scala @@ -5,48 +5,48 @@ import scala.concurrent.duration.* object TimeUtils: - /** Measures elapsed time for executing a block in nanoseconds */ - def time[R](block: => R): (R, Long) = - val t0 = System.nanoTime() - val result = block - val t1 = System.nanoTime() - val elapsed = t1 - t0 - (result, elapsed) - - /** Selects most appropriate TimeUnit for given duration and formats it accordingly */ - def pretty(duration: Long): String = pretty(Duration.fromNanos(duration)) - - private def pretty(duration: Duration): String = - duration match - case d: FiniteDuration => - val nanos = d.toNanos - val unit = chooseUnit(nanos) - val value = nanos.toDouble / NANOSECONDS.convert(1, unit) - - s"%.4g %s".formatLocal(Locale.ROOT, value, abbreviate(unit)) - - case Duration.MinusInf => s"-∞ (minus infinity)" - case Duration.Inf => s"∞ (infinity)" - case _ => "undefined" - - private def chooseUnit(nanos: Long): TimeUnit = - val d = nanos.nanos - - if d.toDays > 0 then DAYS - else if d.toHours > 0 then HOURS - else if d.toMinutes > 0 then MINUTES - else if d.toSeconds > 0 then SECONDS - else if d.toMillis > 0 then MILLISECONDS - else if d.toMicros > 0 then MICROSECONDS - else NANOSECONDS - - private def abbreviate(unit: TimeUnit): String = - unit match - case NANOSECONDS => "ns" - case MICROSECONDS => "μs" - case MILLISECONDS => "ms" - case SECONDS => "s" - case MINUTES => "min" - case HOURS => "h" - case DAYS => "d" + /** Measures elapsed time for executing a block in nanoseconds */ + def time[R](block: => R): (R, Long) = + val t0 = System.nanoTime() + val result = block + val t1 = System.nanoTime() + val elapsed = t1 - t0 + (result, elapsed) + + /** Selects most appropriate TimeUnit for given duration and formats it accordingly */ + def pretty(duration: Long): String = pretty(Duration.fromNanos(duration)) + + private def pretty(duration: Duration): String = + duration match + case d: FiniteDuration => + val nanos = d.toNanos + val unit = chooseUnit(nanos) + val value = nanos.toDouble / NANOSECONDS.convert(1, unit) + + s"%.4g %s".formatLocal(Locale.ROOT, value, abbreviate(unit)) + + case Duration.MinusInf => s"-∞ (minus infinity)" + case Duration.Inf => s"∞ (infinity)" + case _ => "undefined" + + private def chooseUnit(nanos: Long): TimeUnit = + val d = nanos.nanos + + if d.toDays > 0 then DAYS + else if d.toHours > 0 then HOURS + else if d.toMinutes > 0 then MINUTES + else if d.toSeconds > 0 then SECONDS + else if d.toMillis > 0 then MILLISECONDS + else if d.toMicros > 0 then MICROSECONDS + else NANOSECONDS + + private def abbreviate(unit: TimeUnit): String = + unit match + case NANOSECONDS => "ns" + case MICROSECONDS => "μs" + case MILLISECONDS => "ms" + case SECONDS => "s" + case MINUTES => "min" + case HOURS => "h" + case DAYS => "d" end TimeUtils diff --git a/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/dataflow/DataFlowTests.scala b/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/dataflow/DataFlowTests.scala index 0c8c9417..54df1a64 100644 --- a/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/dataflow/DataFlowTests.scala +++ b/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/dataflow/DataFlowTests.scala @@ -260,13 +260,13 @@ class DataFlowTests extends DataFlowCodeToCpgSuite { "Exclusions behind over-taint" should { "not kill flows" in { - cpg.method.name("sink").fullName.l shouldBe Seq("Test0.c:3:3:sink", "sink") + cpg.method.name("sink").fullName.l shouldBe Seq("sink") val source = cpg.method.name("source").methodReturn val sink = cpg.method.name("sink").parameter val flows = sink.reachableByFlows(source) flows.map(flowToResultPairs).toSetMutable shouldBe Set( - List(("RET", -1), ("source()", 6), ("sink(c[1])", 8), ("sink(p1)", -1)) + List(("RET", 2), ("source()", 6), ("sink(c[1])", 8), ("sink(int* cont)", 3)) ) } } @@ -291,7 +291,7 @@ class DataFlowTests extends DataFlowCodeToCpgSuite { val flows = sink.reachableByFlows(source) flows.map(flowToResultPairs).toSetMutable shouldBe Set( - List(("RET", -1), ("source()", 7), ("sink((*arg).field)", 8), ("sink(p1)", -1)) + List(("RET", 3), ("source()", 7), ("sink((*arg).field)", 8), ("sink(int i)", 4)) ) } } @@ -315,7 +315,7 @@ class DataFlowTests extends DataFlowCodeToCpgSuite { val flows = sink.reachableByFlows(source) flows.map(flowToResultPairs).toSetMutable shouldBe Set( - List(("RET", -1), ("source()", 6), ("sink(*arg)", 7), ("sink(p1)", -1)) + List(("RET", 2), ("source()", 6), ("sink(*arg)", 7), ("sink(int i)", 3)) ) } } diff --git a/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/passes/ast/AstCreationPassTests.scala b/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/passes/ast/AstCreationPassTests.scala index 4470c76f..90c277de 100644 --- a/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/passes/ast/AstCreationPassTests.scala +++ b/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/passes/ast/AstCreationPassTests.scala @@ -22,10 +22,10 @@ class AstCreationPassTests extends AbstractPassTest { |char *hello(); |""".stripMargin) { cpg => inside(cpg.method("foo").l) { case List(foo) => - foo.signature shouldBe "char* foo ()" + foo.signature shouldBe "char* ()" } inside(cpg.method("hello").l) { case List(hello) => - hello.signature shouldBe "char* hello ()" + hello.signature shouldBe "char*()" } } @@ -37,7 +37,7 @@ class AstCreationPassTests extends AbstractPassTest { "test.cpp" ) { cpg => inside(cpg.method("foo").l) { case List(m) => - m.signature shouldBe "void foo (int,int*)" + m.signature shouldBe "void (int,int*)" inside(m.parameter.l) { case List(x, args) => x.name shouldBe "x" x.code shouldBe "int x" @@ -124,32 +124,32 @@ class AstCreationPassTests extends AbstractPassTest { inside(cpg.method.fullNameExact(lambda1FullName).l) { case List(l1) => l1.name shouldBe lambda1FullName l1.code should startWith("[] (int a, int b) -> int") - l1.signature shouldBe "int anonymous_lambda_0 (int,int)" + l1.signature shouldBe "int (int,int)" l1.body.code shouldBe "{ return a + b; }" } inside(cpg.method.fullNameExact(lambda2FullName).l) { case List(l2) => l2.name shouldBe lambda2FullName l2.code should startWith("[] (string a, string b) -> string") - l2.signature shouldBe "string anonymous_lambda_1 (string,string)" + l2.signature shouldBe "string (string,string)" l2.body.code shouldBe "{ return a + b; }" } inside(cpg.typeDecl(NamespaceTraversal.globalNamespaceName).head.bindsOut.l) { case List(bX: Binding, bY: Binding) => bX.name shouldBe lambda1FullName - bX.signature shouldBe "int anonymous_lambda_0 (int,int)" + bX.signature shouldBe "int (int,int)" inside(bX.refOut.l) { case List(method: Method) => method.name shouldBe lambda1FullName method.fullName shouldBe lambda1FullName - method.signature shouldBe "int anonymous_lambda_0 (int,int)" + method.signature shouldBe "int (int,int)" } bY.name shouldBe lambda2FullName - bY.signature shouldBe "string anonymous_lambda_1 (string,string)" + bY.signature shouldBe "string (string,string)" inside(bY.refOut.l) { case List(method: Method) => method.name shouldBe lambda2FullName method.fullName shouldBe lambda2FullName - method.signature shouldBe "string anonymous_lambda_1 (string,string)" + method.signature shouldBe "string (string,string)" } } } @@ -168,7 +168,7 @@ class AstCreationPassTests extends AbstractPassTest { ) { cpg => val lambdaName = "anonymous_lambda_0" val lambdaFullName = "Foo.anonymous_lambda_0" - val signature = "int Foo.anonymous_lambda_0 (int,int)" + val signature = "int (int,int)" cpg.member.name("x").order.l shouldBe List(1) @@ -210,7 +210,7 @@ class AstCreationPassTests extends AbstractPassTest { ) { cpg => val lambdaName = "anonymous_lambda_0" val lambdaFullName = "A.B.Foo.anonymous_lambda_0" - val signature = "int A.B.Foo.anonymous_lambda_0 (int,int)" + val signature = "int (int,int)" cpg.member.name("x").order.l shouldBe List(1) @@ -253,9 +253,9 @@ class AstCreationPassTests extends AbstractPassTest { "test.cpp" ) { cpg => val lambda1Name = "anonymous_lambda_0" - val signature1 = "int anonymous_lambda_0 (int)" + val signature1 = "int (int)" val lambda2Name = "anonymous_lambda_1" - val signature2 = "int anonymous_lambda_1 (int)" + val signature2 = "int (int)" cpg.local.name("x").order.l shouldBe List(1) cpg.local.name("foo1").order.l shouldBe List(3) @@ -294,35 +294,6 @@ class AstCreationPassTests extends AbstractPassTest { } } - inside(cpg.call("x").l) { case List(lambda1call) => - lambda1call.name shouldBe "x" - lambda1call.methodFullName shouldBe "x" - lambda1call.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH - inside(lambda1call.astChildren.l) { case List(lit: Literal) => - lit.code shouldBe "10" - } - inside(lambda1call.argument.l) { case List(lit: Literal) => - lit.code shouldBe "10" - } - lambda1call.receiver.l shouldBe empty - } - - inside(cpg.call(lambda2Name).l) { case List(lambda2call) => - lambda2call.name shouldBe lambda2Name - lambda2call.methodFullName shouldBe lambda2Name - lambda2call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH - inside(lambda2call.astChildren.l) { case List(ref: MethodRef, lit: Literal) => - ref.methodFullName shouldBe lambda2Name - ref.code should startWith("[](int n) -> int") - lit.code shouldBe "10" - } - - inside(lambda2call.argument.l) { case List(ref: MethodRef, lit: Literal) => - ref.methodFullName shouldBe lambda2Name - ref.code should startWith("[](int n) -> int") - lit.code shouldBe "10" - } - } } "be correct for empty method" in AstFixture("void method(int x) { }") { cpg => @@ -872,8 +843,8 @@ class AstCreationPassTests extends AbstractPassTest { |} """.stripMargin) { cpg => inside(cpg.method.name("main").ast.isCall.codeExact("(*strLenFunc)(\"123\")").l) { case List(call) => - call.name shouldBe "*strLenFunc" - call.methodFullName shouldBe "*strLenFunc" + call.name shouldBe ".pointerCall" + call.methodFullName shouldBe ".pointerCall" } } @@ -1101,43 +1072,6 @@ class AstCreationPassTests extends AbstractPassTest { .count(_.inheritsFromTypeFullName == List("Base")) shouldBe 1 } - "be correct for field access" in AstFixture( - """ - |class Foo { - |public: - | char x; - | int method(){return i;}; - |}; - | - |Foo f; - |int x = f.method(); - """.stripMargin, - "file.cpp" - ) { cpg => - cpg.typeDecl - .name("Foo") - .l - .size shouldBe 1 - - inside(cpg.call.code("f.method()").l) { case List(call: Call) => - call.methodFullName shouldBe Operators.fieldAccess - call.argument(1).code shouldBe "f" - call.argument(2).code shouldBe "method" - } - } - - "be correct for type initializer expression" in AstFixture( - """ - |int x = (int){ 1 }; - """.stripMargin, - "file.cpp" - ) { cpg => - inside(cpg.call.name(Operators.cast).l) { case List(call: Call) => - call.argument(2).code shouldBe "{ 1 }" - call.argument(1).code shouldBe "int" - } - } - "be correct for static assert" in AstFixture( """ |void foo(){ @@ -1528,7 +1462,7 @@ class AstCreationPassTests extends AbstractPassTest { argsAIdent.name shouldBe "a" argsAIdent.code shouldBe "a" argARef.order shouldBe 2 - argARef.methodFullName shouldBe "file.c:2:2:methodA.methodA" + argARef.methodFullName shouldBe "methodA" argARef.typeFullName shouldBe methodA.methodReturn.typeFullName val argsBIdent = callB.argument(1).asInstanceOf[Identifier] val argBRef = callB.argument(2).asInstanceOf[MethodRef] @@ -1536,7 +1470,7 @@ class AstCreationPassTests extends AbstractPassTest { argsBIdent.code shouldBe "b" argsBIdent.name shouldBe "b" argBRef.order shouldBe 2 - argBRef.methodFullName shouldBe "file.c:3:3:methodB.methodB" + argBRef.methodFullName shouldBe "methodB" argBRef.typeFullName shouldBe methodB.methodReturn.typeFullName } } @@ -1986,9 +1920,9 @@ class AstCreationPassTests extends AbstractPassTest { "be correct for function edge case" in AstFixture("class Foo { char (*(*x())[5])() }", "test.cpp") { cpg => val List(method) = cpg.method.nameNot("").l method.name shouldBe "x" - method.fullName shouldBe "test.cpp:1:1:Foo.x" + method.fullName shouldBe "Foo.x:char (* (*)[5])()()" method.code shouldBe "char (*(*x())[5])()" - method.signature shouldBe "char Foo.x ()" + method.signature shouldBe "char()" } "be consistent with pointer types" in AstFixture(""" diff --git a/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/passes/ast/CallTests.scala b/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/passes/ast/CallTests.scala index 16e8955a..26e657a2 100644 --- a/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/passes/ast/CallTests.scala +++ b/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/passes/ast/CallTests.scala @@ -1,10 +1,12 @@ package io.appthreat.c2cpg.passes.ast +import io.appthreat.c2cpg.astcreation.Defines import io.appthreat.c2cpg.testfixtures.CCodeToCpgSuite -import io.shiftleft.codepropertygraph.generated.Operators +import io.appthreat.x2cpg.Defines as X2CpgDefines +import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} import io.shiftleft.codepropertygraph.generated.nodes.Call import io.shiftleft.codepropertygraph.generated.nodes.Literal -import io.shiftleft.semanticcpg.language._ +import io.shiftleft.semanticcpg.language.* import io.shiftleft.semanticcpg.language.NoResolve class CallTests extends CCodeToCpgSuite { @@ -67,53 +69,6 @@ class CallTests extends CCodeToCpgSuite { } } - "CallTest 2" should { - val cpg = code( - """ - |using namespace std; - | - |class A{ - | public: - | int a; - |}; - | - |class B{ - | public: - | A* GetObj(); - |}; - | - |A* B::GetObj() { - | return nullptr; - |} - | - |class C{ - | public: - | A* GetObject(); - |}; - | - |A* C::GetObject() { - | B * b; - | return b->GetObj(); - |} - | - |bool Run(A *obj, C *c) { - | const A * a = c->GetObject(); - | a->a; - | return true; - |} - |""".stripMargin, - "code.cpp" - ) - - "have the correct callIn" in { - val List(m) = cpg.method.nameNot("").where(_.ast.isReturn.code(".*nullptr.*")).l - val List(c) = cpg.call.codeExact("b->GetObj()").l - c.callee.head shouldBe m - val List(callIn) = m.callIn.l - callIn.code shouldBe "b->GetObj()" - } - } - "CallTest 3" should { val cpg = code( """ @@ -127,8 +82,8 @@ class CallTests extends CCodeToCpgSuite { "test.cpp" ) "have correct names for static methods / calls" in { - cpg.method.name("square").fullName.head shouldBe "square" - cpg.method.name("call_square").call.methodFullName.head shouldBe "square" + cpg.method.name("square").fullName.head shouldBe "square:int(int)" + cpg.method.name("call_square").call.methodFullName.head shouldBe "square:int(int)" } } @@ -149,8 +104,8 @@ class CallTests extends CCodeToCpgSuite { "test.cpp" ) "have correct names for static methods / calls from classes" in { - cpg.method.name("square").fullName.head shouldBe "A.square" - cpg.method.name("call_square").call.methodFullName.head shouldBe "test.cpp:4:4:A.square.square.A.square" + cpg.method.name("square").fullName.head shouldBe "A.square:int(int)" + cpg.method.name("call_square").call.methodFullName.head shouldBe "A.square:int(int)" } } @@ -168,12 +123,452 @@ class CallTests extends CCodeToCpgSuite { ) "have correct type full names for calls" in { val List(bCall) = cpg.call.l - bCall.methodFullName shouldBe "A.b" + bCall.methodFullName shouldBe "A.b:void()" val List(bMethod) = cpg.method.name("b").internal.l - bMethod.fullName shouldBe "A.b" - bMethod.callIn.head shouldBe bCall - bCall.callee.head shouldBe bMethod + bMethod.fullName shouldBe "A.b:void()" + } + } + + "CallTest 6" should { + val cpg = code( + """ + |class A { + | public: + | void foo1(){ + | foo2(); + | } + | static void foo2() {} + |}; + | + |int main() { + | A a; + | a.foo1(); + |} + |""".stripMargin, + "test.cpp" + ) + "have correct type full names for calls" in { + val List(foo2Call) = cpg.call("foo2").l + foo2Call.methodFullName shouldBe "A.foo2:void()" + } + } + "Successfully typed calls" should { + "have correct call for call on non virtual class method" in { + val cpg = code( + """ + |namespace NNN { + | class A { + | public: + | void foo(int a){} + | }; + |} + | + |void outer() { + | NNN::A a; + | a.foo(1); + |} + |""".stripMargin, + "test.cpp" + ) + + val List(call) = cpg.call.nameExact("foo").l + call.signature shouldBe "void(int)" + call.methodFullName shouldBe "NNN.A.foo:void(int)" + call.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + call.typeFullName shouldBe "void" + + val List(instArg, arg1) = call.argument.l + instArg.code shouldBe "a" + instArg.argumentIndex shouldBe 0 + arg1.code shouldBe "1" + arg1.argumentIndex shouldBe 1 + + call.receiver.isEmpty shouldBe true + } + + "have correct call for call on virtual class method" in { + val cpg = code( + """ + |namespace NNN { + | class A { + | public: + | virtual void foo(int a){} + | }; + |} + | + |void outer() { + | NNN::A a; + | a.foo(1); + |} + |""".stripMargin, + "test.cpp" + ) + + val List(call) = cpg.call.nameExact("foo").l + call.signature shouldBe "void(int)" + call.methodFullName shouldBe "NNN.A.foo:void(int)" + call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH + call.typeFullName shouldBe "void" + + val List(instArg, arg1) = call.argument.l + instArg.code shouldBe "a" + instArg.argumentIndex shouldBe 0 + arg1.code shouldBe "1" + arg1.argumentIndex shouldBe 1 + + val List(receiver) = call.receiver.l + receiver shouldBe instArg + } + + "have correct call for call on stand alone method (CPP)" in { + val cpg = code( + """ + |namespace NNN { + | void foo(int a){} + |} + | + |void outer() { + | NNN::foo(1); + |} + |""".stripMargin, + "test.cpp" + ) + + val List(call) = cpg.call.nameExact("foo").l + call.signature shouldBe "void(int)" + call.methodFullName shouldBe "NNN.foo:void(int)" + call.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + call.typeFullName shouldBe "void" + + val List(arg1) = call.argument.l + arg1.code shouldBe "1" + + call.receiver.isEmpty shouldBe true + } + + "have correct call for call on lambda function" in { + val cpg = code( + """ + |void outer() { + | [](int a) {}(1); + |} + |""".stripMargin, + "test.cpp" + ) + + val List(call) = cpg.call.nameExact("()").l + call.signature shouldBe "void(int)" + call.methodFullName shouldBe "():void(int)" + call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH + call.typeFullName shouldBe "void" + + val List(arg1) = call.argument.l + arg1.code shouldBe "1" + arg1.argumentIndex shouldBe 1 + + val List(receiver) = call.receiver.l + receiver.isMethodRef shouldBe true + receiver.argumentIndex shouldBe -1 + } + + "have correct call for call on function pointer (CPP)" in { + val cpg = code( + """ + |class A { + | public: + | void (*foo)(int); + |}; + | + |void outer() { + | A a; + | a.foo(1); + |} + |""".stripMargin, + "test.cpp" + ) + + val List(call) = cpg.call.nameExact(Defines.operatorPointerCall).l + call.signature shouldBe "" + call.methodFullName shouldBe Defines.operatorPointerCall + call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH + call.typeFullName shouldBe "void" + + val List(arg1) = call.argument.l + arg1.code shouldBe "1" + arg1.argumentIndex shouldBe 1 + + val List(receiver) = call.receiver.l + receiver.code shouldBe "a.foo" + receiver.argumentIndex shouldBe -1 + } + + "have correct call for call on callable object" in { + val cpg = code( + """ + |namespace NNN { + | class Callable { + | public: + | void operator()(int a){} + | }; + |} + |class A { + | public: + | NNN::Callable foo; + |}; + | + |void outer() { + | A a; + | a.foo(1); + |} + |""".stripMargin, + "test.cpp" + ) + + val List(call) = cpg.call.nameExact("()").l + call.signature shouldBe "void(int)" + call.methodFullName shouldBe "NNN.Callable.():void(int)" + call.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + call.typeFullName shouldBe "void" + + val List(instArg, arg1) = call.argument.l + instArg.code shouldBe "a.foo" + instArg.argumentIndex shouldBe 0 + arg1.code shouldBe "1" + arg1.argumentIndex shouldBe 1 + + val List(receiver) = call.receiver.l + receiver shouldBe instArg + } + + "have correct call for call on function pointer (C)" in { + val cpg = code( + """ + |struct A { + | void (*foo)(int); + |} + |void outer() { + | struct A a; + | a.foo(1); + |} + |""".stripMargin, + "test.c" + ) + + val List(call) = cpg.call.nameExact(Defines.operatorPointerCall).l + call.signature shouldBe "" + call.methodFullName shouldBe Defines.operatorPointerCall + call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH + call.typeFullName shouldBe "void" + + val List(arg1) = call.argument.l + arg1.code shouldBe "1" + arg1.argumentIndex shouldBe 1 + + val List(receiver) = call.receiver.l + receiver.code shouldBe "a.foo" + receiver.argumentIndex shouldBe -1 + } + + "have correct call for call on stand alone method (C)" in { + val cpg = code( + """ + |void foo(int) {} + |void outer() { + | foo(1); + |} + |""".stripMargin, + "test.c" + ) + + val List(call) = cpg.call.nameExact("foo").l + call.signature shouldBe "" + call.methodFullName shouldBe "foo" + call.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + call.typeFullName shouldBe "void" + + val List(arg1) = call.argument.l + arg1.code shouldBe "1" + + call.receiver.isEmpty shouldBe true + } + + "have correct call for call on extern C function" in { + val cpg = code( + """ + |extern "C" { + | void foo(int); + |} + | + |void outer() { + | foo(1); + |} + |""".stripMargin, + "test.cpp" + ) + + val List(call) = cpg.call.nameExact("foo").l + call.signature shouldBe "" + call.methodFullName shouldBe "foo" + call.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + call.typeFullName shouldBe "void" + + val List(arg1) = call.argument.l + arg1.code shouldBe "1" + arg1.argumentIndex shouldBe 1 + + call.receiver.isEmpty shouldBe true } } + "Not successfully typed calls" should { + "have correct call for field reference style call (CPP)" in { + val cpg = code( + """ + |void outer() { + | Unknown a; + | a.foo(1); + |} + |""".stripMargin, + "test.cpp" + ) + + val List(call) = cpg.call.nameExact("foo").l + call.signature shouldBe X2CpgDefines.UnresolvedSignature + call.methodFullName shouldBe s"${X2CpgDefines.UnresolvedNamespace}.foo:${X2CpgDefines.UnresolvedSignature}(1)" + call.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + call.typeFullName shouldBe X2CpgDefines.Any + + val List(instArg, arg1) = call.argument.l + instArg.code shouldBe "a" + instArg.argumentIndex shouldBe 0 + arg1.code shouldBe "1" + arg1.argumentIndex shouldBe 1 + + val List(receiver) = call.receiver.l + receiver shouldBe instArg + } + + "have correct call for plain call (CPP)" in { + val cpg = code( + """ + |void outer() { + | foo(1); + |} + |""".stripMargin, + "test.cpp" + ) + + val List(call) = cpg.call.nameExact("foo").l + call.signature shouldBe X2CpgDefines.UnresolvedSignature + call.methodFullName shouldBe s"${X2CpgDefines.UnresolvedNamespace}.foo:${X2CpgDefines.UnresolvedSignature}(1)" + call.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + call.typeFullName shouldBe X2CpgDefines.Any + + val List(arg1) = call.argument.l + arg1.code shouldBe "1" + arg1.argumentIndex shouldBe 1 + + call.receiver.isEmpty shouldBe true + } + + "have correct call for call on arbitrary expression (CPP)" in { + val cpg = code( + """ + |void outer() { + | getX()(1); + |} + |""".stripMargin, + "test.cpp" + ) + + val List(call) = cpg.call.nameExact("()").l + call.signature shouldBe X2CpgDefines.UnresolvedSignature + call.methodFullName shouldBe s"${X2CpgDefines.UnresolvedNamespace}.():${X2CpgDefines.UnresolvedSignature}(1)" + call.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + call.typeFullName shouldBe X2CpgDefines.Any + + val List(instArg, arg1) = call.argument.l + instArg.code shouldBe "getX()" + instArg.argumentIndex shouldBe 0 + arg1.code shouldBe "1" + arg1.argumentIndex shouldBe 1 + + val List(receiver) = call.receiver.l + receiver shouldBe instArg + } + + "have correct call for field reference style call (C)" in { + val cpg = code( + """ + |void outer() { + | struct A a; + | a.foo(1); + |} + |""".stripMargin, + "test.c" + ) + + val List(call) = cpg.call.nameExact(Defines.operatorPointerCall).l + call.signature shouldBe "" + call.methodFullName shouldBe Defines.operatorPointerCall + call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH + call.typeFullName shouldBe X2CpgDefines.Any + + val List(arg1) = call.argument.l + arg1.code shouldBe "1" + arg1.argumentIndex shouldBe 1 + + val List(receiver) = call.receiver.l + receiver.code shouldBe "a.foo" + receiver.argumentIndex shouldBe -1 + } + + "have correct call for plain call (C)" in { + val cpg = code( + """ + |void outer() { + | foo(1); + |} + |""".stripMargin, + "test.c" + ) + + val List(call) = cpg.call.nameExact("foo").l + call.signature shouldBe "" + call.methodFullName shouldBe s"foo" + call.dispatchType shouldBe DispatchTypes.STATIC_DISPATCH + call.typeFullName shouldBe X2CpgDefines.Any + + val List(arg1) = call.argument.l + arg1.code shouldBe "1" + arg1.argumentIndex shouldBe 1 + + call.receiver.isEmpty shouldBe true + } + + "have correct call for call on arbitrary expression (C)" in { + val cpg = code( + """ + |void outer() { + | getX()(1); + |} + |""".stripMargin, + "test.c" + ) + + val List(call) = cpg.call.nameExact(Defines.operatorPointerCall).l + call.signature shouldBe "" + call.methodFullName shouldBe Defines.operatorPointerCall + call.dispatchType shouldBe DispatchTypes.DYNAMIC_DISPATCH + call.typeFullName shouldBe X2CpgDefines.Any + + val List(arg1) = call.argument.l + arg1.code shouldBe "1" + arg1.argumentIndex shouldBe 1 + + val List(receiver) = call.receiver.l + receiver.code shouldBe "getX()" + receiver.argumentIndex shouldBe -1 + } + } } diff --git a/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/passes/ast/HeaderAstCreationPassTests.scala b/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/passes/ast/HeaderAstCreationPassTests.scala index 119003fc..bc0bce96 100644 --- a/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/passes/ast/HeaderAstCreationPassTests.scala +++ b/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/passes/ast/HeaderAstCreationPassTests.scala @@ -43,15 +43,15 @@ class HeaderAstCreationPassTests extends CCodeToCpgSuite { bar.fullName shouldBe "bar" bar.filename shouldBe "main.h" - foo.fullName shouldBe "main" - foo.filename shouldBe "main.c" + foo.fullName shouldBe "foo" + foo.filename shouldBe "other.h" // main is include twice. First time for the header file, // second time for the actual implementation in the source file // We do not de-duplicate this as line/column numbers differ - m1.fullName shouldBe "main.h:2:2:main" + m1.fullName shouldBe "main" m1.filename shouldBe "main.h" - m2.fullName shouldBe "other.h:2:2:foo" - m2.filename shouldBe "other.h" + m2.fullName shouldBe "main" + m2.filename shouldBe "main.c" printf.fullName shouldBe "printf" } } diff --git a/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/passes/ast/MethodTests.scala b/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/passes/ast/MethodTests.scala index 646eec38..01821f66 100644 --- a/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/passes/ast/MethodTests.scala +++ b/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/passes/ast/MethodTests.scala @@ -19,7 +19,7 @@ class MethodTests extends CCodeToCpgSuite { x.name shouldBe "main" x.fullName shouldBe "main" x.code should startWith("int main(int argc, char **argv) {") - x.signature shouldBe "int main (int,char**)" + x.signature shouldBe "int (int,char**)" x.isExternal shouldBe false x.order shouldBe 1 x.filename shouldBe "Test0.c" @@ -77,7 +77,7 @@ class MethodTests extends CCodeToCpgSuite { "should not generate a type decl for method declarations" in { inside(cpg.method.name("doFoo").l) { case List(x) => x.name shouldBe "doFoo" - x.fullName shouldBe "Test0.c:1:1:doFoo" + x.fullName shouldBe "doFoo" x.astParentType shouldBe NodeTypes.TYPE_DECL x.astParentFullName should endWith(NamespaceTraversal.globalNamespaceName) } @@ -112,8 +112,8 @@ class MethodTests extends CCodeToCpgSuite { "should be correct for methods with line breaks / whitespace" in { inside(cpg.method("foo").l) { case List(foo) => foo.name shouldBe "foo" - foo.fullName shouldBe "foo" - foo.signature shouldBe "void foo ()" + foo.fullName shouldBe "foo:void()" + foo.signature shouldBe "void ()" } } } @@ -129,7 +129,7 @@ class MethodTests extends CCodeToCpgSuite { val List(method) = cpg.method.nameExact("foo").l method.isExternal shouldBe false method.fullName shouldBe "foo" - method.signature shouldBe "int foo (int,int)" + method.signature shouldBe "int (int,int)" method.lineNumber shouldBe Option(2) method.columnNumber shouldBe Option(1) method.lineNumberEnd shouldBe Option(4) @@ -265,7 +265,7 @@ class MethodTests extends CCodeToCpgSuite { ) "deduplicate method forward declarations correctly" in { - cpg.method.fullName("abs").size shouldBe 1 + cpg.method.name("abs").size shouldBe 2 cpg.call.name("abs").callee(NoResolve).size shouldBe 1 } diff --git a/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/passes/types/ClassTypeTests.scala b/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/passes/types/ClassTypeTests.scala index d31956f2..f781a64c 100644 --- a/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/passes/types/ClassTypeTests.scala +++ b/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/passes/types/ClassTypeTests.scala @@ -163,8 +163,8 @@ class ClassTypeTests extends CCodeToCpgSuite(FileDefaults.CPP_EXT): | return 0; |}""".stripMargin ) - cpg.call("foo1").methodFullName.toSetMutable shouldBe Set("A.foo1") - cpg.call("foo2").methodFullName.toSetMutable shouldBe Set("B.foo2") + cpg.call("foo1").methodFullName.toSetMutable shouldBe Set("A.foo1:void()") + cpg.call("foo2").methodFullName.toSetMutable shouldBe Set("B.foo2:void()") } } @@ -181,7 +181,7 @@ class ClassTypeTests extends CCodeToCpgSuite(FileDefaults.CPP_EXT): |}""".stripMargin ) val List(constructor) = cpg.typeDecl.nameExact("FooT").method.isConstructor.l - constructor.signature shouldBe "Bar.Foo FooT.FooT (std.string,Bar.SomeClass)" + constructor.signature shouldBe "Bar.Foo (std.string,Bar.SomeClass)" val List(p1, p2) = constructor.parameter.l p1.typ.fullName shouldBe "std.string" p2.typ.fullName shouldBe "Bar.SomeClass" diff --git a/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/passes/types/NamespaceTypeTests.scala b/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/passes/types/NamespaceTypeTests.scala index dcb94d0f..56ee2f2f 100644 --- a/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/passes/types/NamespaceTypeTests.scala +++ b/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/passes/types/NamespaceTypeTests.scala @@ -35,8 +35,8 @@ class NamespaceTypeTests extends CCodeToCpgSuite(fileSuffix = FileDefaults.CPP_E |} |""".stripMargin) inside(cpg.method.isNotStub.fullName.l) { case List(f, m) => - f shouldBe "Q.V.f" - m shouldBe "Q.V.C.m" + f shouldBe "Q.V.f:int()" + m shouldBe "Q.V.C.m:int()" } inside(cpg.namespaceBlock.nameNot("").l) { case List(q, v) => @@ -50,9 +50,9 @@ class NamespaceTypeTests extends CCodeToCpgSuite(fileSuffix = FileDefaults.CPP_E inside(q.method.l) { case List(f, m) => f.name shouldBe "f" - f.fullName shouldBe "Q.V.f" + f.fullName shouldBe "Q.V.f:int()" m.name shouldBe "m" - m.fullName shouldBe "Q.V.C.m" + m.fullName shouldBe "Q.V.C.m:int()" } } @@ -80,9 +80,9 @@ class NamespaceTypeTests extends CCodeToCpgSuite(fileSuffix = FileDefaults.CPP_E inside(cpg.method.nameNot("").fullName.l) { case List(m1, f1, f2, h, m2) => // TODO: this looks strange too it first glance. But as Eclipse CDT does not provide any // mapping from definitions outside of namespace into them we cant reconstruct proper full-names. - m1 shouldBe "Test0.cpp:3:3:Q.V.C.m" - f1 shouldBe "Test0.cpp:5:5:Q.V.f" - h shouldBe "Test0.cpp:11:11:V.f.h" + m1 shouldBe "Q.V.C.m:int()" + f1 shouldBe "Q.V.f:int()" + h shouldBe "h:void()" f2 shouldBe "V.f" m2 shouldBe "V.C.m" } @@ -129,9 +129,9 @@ class NamespaceTypeTests extends CCodeToCpgSuite(fileSuffix = FileDefaults.CPP_E } inside(cpg.method.internal.nameNot("").fullName.l) { case List(f, g, h) => - f shouldBe "f" - g shouldBe "A.g" - h shouldBe "h" + f shouldBe "f:void()" + g shouldBe "A.g:void()" + h shouldBe "h:void()" } inside(cpg.method.nameExact("h").ast.isCall.code.l) { case List(c1, c2, c3, c4) => @@ -165,16 +165,16 @@ class NamespaceTypeTests extends CCodeToCpgSuite(fileSuffix = FileDefaults.CPP_E } inside(cpg.method.internal.nameNot("").fullName.l) { case List(f, g, h) => - f shouldBe "Test0.cpp:2:2:f" - g shouldBe "Test0.cpp:5:5:A.g" - h shouldBe "h" + f shouldBe "f:void()" + g shouldBe "A.g:void()" + h shouldBe "h:void()" } inside(cpg.call.filterNot(_.name == Operators.fieldAccess).l) { case List(f, g) => - f.name shouldBe "X.f" - f.methodFullName shouldBe "X.f" - g.name shouldBe "X.g" - g.methodFullName shouldBe "X.g" + f.name shouldBe "f" + f.methodFullName shouldBe "f:void()" + g.name shouldBe "g" + g.methodFullName shouldBe "A.g:void()" } } @@ -204,19 +204,19 @@ class NamespaceTypeTests extends CCodeToCpgSuite(fileSuffix = FileDefaults.CPP_E } inside(cpg.method.internal.nameNot("").l) { case List(f1, f2, foo, bar) => - f1.fullName shouldBe "Test0.cpp:3:3:A.f" - f1.signature shouldBe "void A.f (int)" - f2.fullName shouldBe "Test0.cpp:8:8:A.f" - f2.signature shouldBe "void A.f (char)" - foo.fullName shouldBe "foo" - bar.fullName shouldBe "bar" + f1.fullName shouldBe "A.f:void(int)" + f1.signature shouldBe "void(int)" + f2.fullName shouldBe "A.f:void(char)" + f2.signature shouldBe "void(char)" + foo.fullName shouldBe "foo:void()" + bar.fullName shouldBe "bar:void()" } inside(cpg.call.l) { case List(c1, c2) => c1.name shouldBe "f" - c1.methodFullName shouldBe "f" + c1.methodFullName shouldBe "A.f:void(char)" c2.name shouldBe "f" - c2.methodFullName shouldBe "f" + c2.methodFullName shouldBe "A.f:void(char)" } } diff --git a/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/passes/types/TemplateTypeTests.scala b/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/passes/types/TemplateTypeTests.scala index 26edaf22..70d2a91e 100644 --- a/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/passes/types/TemplateTypeTests.scala +++ b/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/passes/types/TemplateTypeTests.scala @@ -72,11 +72,11 @@ class TemplateTypeTests extends CCodeToCpgSuite(fileSuffix = FileDefaults.CPP_EX |""".stripMargin) inside(cpg.method.nameNot("").internal.l) { case List(x, y) => x.name shouldBe "x" - x.fullName shouldBe "x" - x.signature shouldBe "void x (T,U)" + x.fullName shouldBe "x:void(#0,#1)" + x.signature shouldBe "void (T,U)" y.name shouldBe "y" - y.fullName shouldBe "Test0.cpp:6:6:y" - y.signature shouldBe "void y (T,U)" + y.fullName shouldBe "y:void(#0,#1)" + y.signature shouldBe "void(T,U)" } } diff --git a/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/querying/CallGraphQueryTests.scala b/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/querying/CallGraphQueryTests.scala index 297ddc4c..91615f59 100644 --- a/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/querying/CallGraphQueryTests.scala +++ b/platform/frontends/c2cpg/src/test/scala/io/appthreat/c2cpg/querying/CallGraphQueryTests.scala @@ -83,10 +83,10 @@ class CallGraphQueryTests extends CCodeToCpgSuite { "test.cpp" ) "have correct type full names for calls" in { - cpg.method.name("eat").fullName.toSet shouldBe Set("Animal.eat", "Cat.eat", "People.eat") - cpg.method.fullName("main").callee.fullName.toSet shouldBe Set("Animal.eat", "Cat.eat") - cpg.method.fullName("Cat.eat").caller.fullName.toSet shouldBe Set("main") - cpg.method.fullName("Animal.eat").caller.fullName.toSet shouldBe Set("People.feedPets", "main") + cpg.method.name("eat").fullName.toSet shouldBe Set("Animal.eat:void()", "Cat.eat:void()", "People.eat:void()") + cpg.method.name("main").callee.fullName.toSet shouldBe Set("Animal.eat:void()", "Cat.eat:void()", ".addressOf") + cpg.method.fullNameExact("Cat.eat:void()").caller.fullName.toSet shouldBe Set("People.feedPets:void()", "main:int()") + cpg.method.fullName("Animal.eat.*").caller.fullName.toSet shouldBe Set("People.feedPets:void()", "main:int()") } } } diff --git a/platform/frontends/javasrc2cpg/build.sbt b/platform/frontends/javasrc2cpg/build.sbt index bf83efb5..07ef0350 100644 --- a/platform/frontends/javasrc2cpg/build.sbt +++ b/platform/frontends/javasrc2cpg/build.sbt @@ -19,7 +19,7 @@ Global / onChangedBuildSource := ReloadOnSourceChanges lazy val packTestCode = taskKey[Unit]("Packs test code for JarTypeReader into jars.") packTestCode := { - import better.files._ + import better.files.* import net.lingala.zip4j.ZipFile import net.lingala.zip4j.model.ZipParameters import net.lingala.zip4j.model.enums.{CompressionLevel, CompressionMethod} diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/JavaSrc2Cpg.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/JavaSrc2Cpg.scala index e939aeff..ccc36d07 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/JavaSrc2Cpg.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/JavaSrc2Cpg.scala @@ -20,59 +20,59 @@ import scala.util.Try import scala.util.matching.Regex class JavaSrc2Cpg extends X2CpgFrontend[Config]: - import JavaSrc2Cpg.* + import JavaSrc2Cpg.* - override def createCpg(config: Config): Try[Cpg] = - withNewEmptyCpg(config.outputPath, config: Config) { (cpg, config) => - new MetaDataPass(cpg, language, config.inputPath).createAndApply() - val astCreationPass = new AstCreationPass(config, cpg) - astCreationPass.createAndApply() - astCreationPass.clearJavaParserCaches() - new ConfigFileCreationPass(cpg).createAndApply() - if !config.skipTypeInfPass then - TypeNodePass.withRegisteredTypes( - astCreationPass.global.usedTypes.keys().asScala.toList, - cpg - ).createAndApply() - new TypeInferencePass(cpg).createAndApply() - } + override def createCpg(config: Config): Try[Cpg] = + withNewEmptyCpg(config.outputPath, config: Config) { (cpg, config) => + new MetaDataPass(cpg, language, config.inputPath).createAndApply() + val astCreationPass = new AstCreationPass(config, cpg) + astCreationPass.createAndApply() + astCreationPass.clearJavaParserCaches() + new ConfigFileCreationPass(cpg).createAndApply() + if !config.skipTypeInfPass then + TypeNodePass.withRegisteredTypes( + astCreationPass.global.usedTypes.keys().asScala.toList, + cpg + ).createAndApply() + new TypeInferencePass(cpg).createAndApply() + } object JavaSrc2Cpg: - val language: String = Languages.JAVASRC + val language: String = Languages.JAVASRC - val sourceFileExtensions: Set[String] = Set(".java") + val sourceFileExtensions: Set[String] = Set(".java") - val DefaultIgnoredFilesRegex: List[Regex] = List(".git", ".mvn", "test").flatMap { directory => - List(s"(^|\\\\)$directory($$|\\\\)".r.unanchored, s"(^|/)$directory($$|/)".r.unanchored) - } + val DefaultIgnoredFilesRegex: List[Regex] = List(".git", ".mvn", "test").flatMap { directory => + List(s"(^|\\\\)$directory($$|\\\\)".r.unanchored, s"(^|/)$directory($$|/)".r.unanchored) + } - val DefaultConfig: Config = - Config().withDefaultIgnoredFilesRegex(DefaultIgnoredFilesRegex) + val DefaultConfig: Config = + Config().withDefaultIgnoredFilesRegex(DefaultIgnoredFilesRegex) - def apply(): JavaSrc2Cpg = new JavaSrc2Cpg() + def apply(): JavaSrc2Cpg = new JavaSrc2Cpg() - def typeRecoveryPasses(cpg: Cpg, config: Option[Config] = None): List[CpgPassBase] = - List( - new JavaTypeRecoveryPass( - cpg, - XTypeRecoveryConfig(enabledDummyTypes = !config.exists(_.disableDummyTypes)) - ), - new JavaTypeHintCallLinker(cpg) - ) + def typeRecoveryPasses(cpg: Cpg, config: Option[Config] = None): List[CpgPassBase] = + List( + new JavaTypeRecoveryPass( + cpg, + XTypeRecoveryConfig(enabledDummyTypes = !config.exists(_.disableDummyTypes)) + ), + new JavaTypeHintCallLinker(cpg) + ) - def showEnv(): Unit = - val value = - JavaSrcEnvVar.values.foreach { envVar => - val currentValue = Option(System.getenv(envVar.name)).getOrElse("") - println(s"${envVar.name}:") - println(s" Description : ${envVar.description}") - println(s" Current value: $currentValue") - } + def showEnv(): Unit = + val value = + JavaSrcEnvVar.values.foreach { envVar => + val currentValue = Option(System.getenv(envVar.name)).getOrElse("") + println(s"${envVar.name}:") + println(s" Description : ${envVar.description}") + println(s" Current value: $currentValue") + } - enum JavaSrcEnvVar(val name: String, val description: String): - case JdkPath - extends JavaSrcEnvVar( - "JAVASRC_JDK_PATH", - "Path to the JDK home used for retrieving type information about builtin Java types." - ) + enum JavaSrcEnvVar(val name: String, val description: String): + case JdkPath + extends JavaSrcEnvVar( + "JAVASRC_JDK_PATH", + "Path to the JDK home used for retrieving type information about builtin Java types." + ) end JavaSrc2Cpg diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/Main.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/Main.scala index 6a2c8ed5..63a78013 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/Main.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/Main.scala @@ -23,112 +23,112 @@ final case class Config( dumpJavaparserAsts: Boolean = false ) extends X2CpgConfig[Config] with TypeRecoveryParserConfig[Config]: - def withInferenceJarPaths(paths: Set[String]): Config = - copy(inferenceJarPaths = paths).withInheritedFields(this) + def withInferenceJarPaths(paths: Set[String]): Config = + copy(inferenceJarPaths = paths).withInheritedFields(this) - def withFetchDependencies(value: Boolean): Config = - copy(fetchDependencies = value).withInheritedFields(this) + def withFetchDependencies(value: Boolean): Config = + copy(fetchDependencies = value).withInheritedFields(this) - def withJavaFeatureSetVersion(version: String): Config = - copy(javaFeatureSetVersion = Some(version)).withInheritedFields(this) + def withJavaFeatureSetVersion(version: String): Config = + copy(javaFeatureSetVersion = Some(version)).withInheritedFields(this) - def withDelombokJavaHome(path: String): Config = - copy(delombokJavaHome = Some(path)).withInheritedFields(this) + def withDelombokJavaHome(path: String): Config = + copy(delombokJavaHome = Some(path)).withInheritedFields(this) - def withDelombokMode(mode: String): Config = - copy(delombokMode = Some(mode)).withInheritedFields(this) + def withDelombokMode(mode: String): Config = + copy(delombokMode = Some(mode)).withInheritedFields(this) - def withEnableTypeRecovery(value: Boolean): Config = - copy(enableTypeRecovery = value).withInheritedFields(this) + def withEnableTypeRecovery(value: Boolean): Config = + copy(enableTypeRecovery = value).withInheritedFields(this) - def withJdkPath(path: String): Config = - copy(jdkPath = Some(path)).withInheritedFields(this) + def withJdkPath(path: String): Config = + copy(jdkPath = Some(path)).withInheritedFields(this) - def withShowEnv(value: Boolean): Config = - copy(showEnv = value).withInheritedFields(this) + def withShowEnv(value: Boolean): Config = + copy(showEnv = value).withInheritedFields(this) - def withSkipTypeInfPass(value: Boolean): Config = - copy(skipTypeInfPass = value).withInheritedFields(this) + def withSkipTypeInfPass(value: Boolean): Config = + copy(skipTypeInfPass = value).withInheritedFields(this) - def withDumpJavaparserAsts(value: Boolean): Config = - copy(dumpJavaparserAsts = value).withInheritedFields(this) + def withDumpJavaparserAsts(value: Boolean): Config = + copy(dumpJavaparserAsts = value).withInheritedFields(this) end Config private object Frontend: - implicit val defaultConfig: Config = JavaSrc2Cpg.DefaultConfig - - val cmdLineParser: OParser[Unit, Config] = - val builder = OParser.builder[Config] - import builder.* - OParser.sequence( - programName("javasrc2cpg"), - opt[String]("inference-jar-paths") - .text(s"extra jars used only for type information") - .action((path, c) => c.withInferenceJarPaths(c.inferenceJarPaths + path)), - opt[Unit]("fetch-dependencies") - .text("attempt to fetch dependencies jars for extra type information") - .action((_, c) => c.withFetchDependencies(true)), - opt[String]("delombok-java-home") - .text( - "Optional override to set java home used to run Delombok. Java 17 is recommended for the best results." - ) - .action((path, c) => c.withDelombokJavaHome(path)), - opt[String]("delombok-mode") - .text( - """Specifies how delombok should be executed. Options are + implicit val defaultConfig: Config = JavaSrc2Cpg.DefaultConfig + + val cmdLineParser: OParser[Unit, Config] = + val builder = OParser.builder[Config] + import builder.* + OParser.sequence( + programName("javasrc2cpg"), + opt[String]("inference-jar-paths") + .text(s"extra jars used only for type information") + .action((path, c) => c.withInferenceJarPaths(c.inferenceJarPaths + path)), + opt[Unit]("fetch-dependencies") + .text("attempt to fetch dependencies jars for extra type information") + .action((_, c) => c.withFetchDependencies(true)), + opt[String]("delombok-java-home") + .text( + "Optional override to set java home used to run Delombok. Java 17 is recommended for the best results." + ) + .action((path, c) => c.withDelombokJavaHome(path)), + opt[String]("delombok-mode") + .text( + """Specifies how delombok should be executed. Options are | no-delombok => do not use delombok for analysis or type information. | default => run delombok if a lombok dependency is found and analyse delomboked code. | types-only => run delombok, but use it for type information only | run-delombok => run delombok and use delomboked source for both analysis and type information.""".stripMargin - ) - .action((mode, c) => c.withDelombokMode(mode)), - opt[Unit]("enable-type-recovery") - .hidden() - .action((_, c) => c.withEnableTypeRecovery(true)) - .text("enable generic type recovery"), - XTypeRecovery.parserOptions, - opt[String]("jdk-path") - .action((path, c) => c.withJdkPath(path)) - .text( - "JDK used for resolving builtin Java types. If not set, current classpath will be used" - ), - opt[Unit]("show-env") - .action((_, c) => c.withShowEnv(true)) - .text("print information about environment variables used by javasrc2cpg and exit."), - opt[Unit]("skip-type-inf-pass") - .hidden() - .action((_, c) => c.withSkipTypeInfPass(true)) - .text( - "Skip the type inference pass. Results will be much worse, so should only be used for development purposes" - ), - opt[Unit]("dump-javaparser-asts") - .hidden() - .action((_, c) => c.withDumpJavaparserAsts(true)) - .text( - "Dump the javaparser asts for the given input files and terminate (for debugging)." - ) - ) - end cmdLineParser + ) + .action((mode, c) => c.withDelombokMode(mode)), + opt[Unit]("enable-type-recovery") + .hidden() + .action((_, c) => c.withEnableTypeRecovery(true)) + .text("enable generic type recovery"), + XTypeRecovery.parserOptions, + opt[String]("jdk-path") + .action((path, c) => c.withJdkPath(path)) + .text( + "JDK used for resolving builtin Java types. If not set, current classpath will be used" + ), + opt[Unit]("show-env") + .action((_, c) => c.withShowEnv(true)) + .text("print information about environment variables used by javasrc2cpg and exit."), + opt[Unit]("skip-type-inf-pass") + .hidden() + .action((_, c) => c.withSkipTypeInfPass(true)) + .text( + "Skip the type inference pass. Results will be much worse, so should only be used for development purposes" + ), + opt[Unit]("dump-javaparser-asts") + .hidden() + .action((_, c) => c.withDumpJavaparserAsts(true)) + .text( + "Dump the javaparser asts for the given input files and terminate (for debugging)." + ) + ) + end cmdLineParser end Frontend object Main extends X2CpgMain(cmdLineParser, new JavaSrc2Cpg()): - override def main(args: Array[String]): Unit = - // TODO: This is a hack to allow users to use the "--show-env" option without having - // to specify an input argument. Clean this up when adding this option to more frontends. - if args.contains("--show-env") then - super.main(Array("--show-env", "")) - else - super.main(args) - - def run(config: Config, javasrc2Cpg: JavaSrc2Cpg): Unit = - if config.showEnv then - JavaSrc2Cpg.showEnv() - else if config.dumpJavaparserAsts then - JavaParserAstPrinter.printJpAsts(config) - else - javasrc2Cpg.run(config) - - def getCmdLineParser: OParser[Unit, Config] = cmdLineParser + override def main(args: Array[String]): Unit = + // TODO: This is a hack to allow users to use the "--show-env" option without having + // to specify an input argument. Clean this up when adding this option to more frontends. + if args.contains("--show-env") then + super.main(Array("--show-env", "")) + else + super.main(args) + + def run(config: Config, javasrc2Cpg: JavaSrc2Cpg): Unit = + if config.showEnv then + JavaSrc2Cpg.showEnv() + else if config.dumpJavaparserAsts then + JavaParserAstPrinter.printJpAsts(config) + else + javasrc2Cpg.run(config) + + def getCmdLineParser: OParser[Unit, Config] = cmdLineParser end Main diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/JarTypeReader.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/JarTypeReader.scala index 7b7cf7d8..cadb036a 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/JarTypeReader.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/JarTypeReader.scala @@ -27,103 +27,103 @@ import scala.util.{Failure, Success, Try} object JarTypeReader: - private val logger = LoggerFactory.getLogger(this.getClass) - - val ObjectTypeSignature: ClassTypeSignature = - val packageSpecifier = Some("java.lang") - val signature = NameWithTypeArgs(name = "Object", typeArguments = Nil) - ClassTypeSignature(packageSpecifier, signature, suffix = Nil) - - def getTypes(jarPath: String): List[ResolvedTypeDecl] = - val cp = new ClassPool() - cp.insertClassPath(jarPath) - - val jarFile = new JarFile(jarPath) - jarFile - .entries() - .asScala - .filter(_.getName.endsWith(".class")) - .map { entry => - entry.getRealName.replace("/", ".").dropRight(6) // Drop .class extension - } - .flatMap(getTypeDeclForEntry(cp, _)) - .toList - - def getTypeEntryForMethod(method: CtMethod, parentDecl: ResolvedTypeDecl): ResolvedMethod = - val name = method.getName - val signatureDescriptor = Option(method.getGenericSignature).getOrElse(method.getSignature) - val signature = DescriptorParser.parseMethodSignature(signatureDescriptor) - val isAbstract = method.isEmpty - - model.ResolvedMethod(name, parentDecl, signature, isAbstract) - - /** This should only be used for classes without a generic signature. Generic type information - * will be missing otherwise. - */ - private def classTypeSignatureFromString(signature: String): ClassTypeSignature = - signature.split('.').toList match - case Nil => - logger.debug(s"$signature is not a valid class signature") - ClassTypeSignature(None, NameWithTypeArgs("", Nil), Nil) - - case name :: Nil => - ClassTypeSignature(None, typedName = NameWithTypeArgs(name, Nil), Nil) - - case nameWithPkg => - val packagePrefix = nameWithPkg.init.mkString(".") - val name = nameWithPkg.last - ClassTypeSignature(Some(packagePrefix), NameWithTypeArgs(name, Nil), Nil) - - /** This should only be used for classes without a generic signature. Generic type information - * will be missing otherwise. - */ - private def getCtClassSignature(ctClass: CtClass): ClassSignature = - val classFile = ctClass.getClassFile2 - val typeParameters = Nil - val superclassSignature = - Option.unless(classFile.isInterface)( - classTypeSignatureFromString(classFile.getSuperclass) + private val logger = LoggerFactory.getLogger(this.getClass) + + val ObjectTypeSignature: ClassTypeSignature = + val packageSpecifier = Some("java.lang") + val signature = NameWithTypeArgs(name = "Object", typeArguments = Nil) + ClassTypeSignature(packageSpecifier, signature, suffix = Nil) + + def getTypes(jarPath: String): List[ResolvedTypeDecl] = + val cp = new ClassPool() + cp.insertClassPath(jarPath) + + val jarFile = new JarFile(jarPath) + jarFile + .entries() + .asScala + .filter(_.getName.endsWith(".class")) + .map { entry => + entry.getRealName.replace("/", ".").dropRight(6) // Drop .class extension + } + .flatMap(getTypeDeclForEntry(cp, _)) + .toList + + def getTypeEntryForMethod(method: CtMethod, parentDecl: ResolvedTypeDecl): ResolvedMethod = + val name = method.getName + val signatureDescriptor = Option(method.getGenericSignature).getOrElse(method.getSignature) + val signature = DescriptorParser.parseMethodSignature(signatureDescriptor) + val isAbstract = method.isEmpty + + model.ResolvedMethod(name, parentDecl, signature, isAbstract) + + /** This should only be used for classes without a generic signature. Generic type information + * will be missing otherwise. + */ + private def classTypeSignatureFromString(signature: String): ClassTypeSignature = + signature.split('.').toList match + case Nil => + logger.debug(s"$signature is not a valid class signature") + ClassTypeSignature(None, NameWithTypeArgs("", Nil), Nil) + + case name :: Nil => + ClassTypeSignature(None, typedName = NameWithTypeArgs(name, Nil), Nil) + + case nameWithPkg => + val packagePrefix = nameWithPkg.init.mkString(".") + val name = nameWithPkg.last + ClassTypeSignature(Some(packagePrefix), NameWithTypeArgs(name, Nil), Nil) + + /** This should only be used for classes without a generic signature. Generic type information + * will be missing otherwise. + */ + private def getCtClassSignature(ctClass: CtClass): ClassSignature = + val classFile = ctClass.getClassFile2 + val typeParameters = Nil + val superclassSignature = + Option.unless(classFile.isInterface)( + classTypeSignatureFromString(classFile.getSuperclass) + ) + val interfacesSignatures = classFile.getInterfaces.map(classTypeSignatureFromString).toList + + ClassSignature(typeParameters, superclassSignature, interfacesSignatures) + + private def getResolvedField(ctField: CtField): ResolvedVariableType = + val name = ctField.getName + val signatureDescriptor = + Option(ctField.getGenericSignature).getOrElse(ctField.getSignature) + val signature = DescriptorParser.parseFieldSignature(signatureDescriptor) + + model.ResolvedVariableType(name, signature) + + def getTypeDeclForEntry(cp: ClassPool, name: String): Option[ResolvedTypeDecl] = + Try(cp.get(name)) match + case Success(ctClass) => + val name = ctClass.getSimpleName + val packageSpecifier = ctClass.getPackageName + val signature = Option(ctClass.getGenericSignature) + .map(DescriptorParser.parseClassSignature) + .getOrElse(getCtClassSignature(ctClass)) + val isInterface = ctClass.isInterface + val isAbstract = !isInterface && ctClass.getClassFile2.isAbstract + val fields = ctClass.getFields.map(getResolvedField).toList + val typeDecl = model.ResolvedTypeDecl( + name, + Some(packageSpecifier), + signature, + isInterface, + isAbstract, + fields ) - val interfacesSignatures = classFile.getInterfaces.map(classTypeSignatureFromString).toList - - ClassSignature(typeParameters, superclassSignature, interfacesSignatures) - - private def getResolvedField(ctField: CtField): ResolvedVariableType = - val name = ctField.getName - val signatureDescriptor = - Option(ctField.getGenericSignature).getOrElse(ctField.getSignature) - val signature = DescriptorParser.parseFieldSignature(signatureDescriptor) - - model.ResolvedVariableType(name, signature) - - def getTypeDeclForEntry(cp: ClassPool, name: String): Option[ResolvedTypeDecl] = - Try(cp.get(name)) match - case Success(ctClass) => - val name = ctClass.getSimpleName - val packageSpecifier = ctClass.getPackageName - val signature = Option(ctClass.getGenericSignature) - .map(DescriptorParser.parseClassSignature) - .getOrElse(getCtClassSignature(ctClass)) - val isInterface = ctClass.isInterface - val isAbstract = !isInterface && ctClass.getClassFile2.isAbstract - val fields = ctClass.getFields.map(getResolvedField).toList - val typeDecl = model.ResolvedTypeDecl( - name, - Some(packageSpecifier), - signature, - isInterface, - isAbstract, - fields - ) - val methods = ctClass.getMethods.map(getTypeEntryForMethod(_, typeDecl)).toList - typeDecl.addMethods(methods) - - Some(typeDecl) - - case Failure(_: NotFoundException) => - // TODO: Expected for interfaces, but fix this - None - - case Failure(_) => - None + val methods = ctClass.getMethods.map(getTypeEntryForMethod(_, typeDecl)).toList + typeDecl.addMethods(methods) + + Some(typeDecl) + + case Failure(_: NotFoundException) => + // TODO: Expected for interfaces, but fix this + None + + case Failure(_) => + None end JarTypeReader diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/descriptorparser/DescriptorParser.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/descriptorparser/DescriptorParser.scala index dcc341e9..e7b10ae6 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/descriptorparser/DescriptorParser.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/descriptorparser/DescriptorParser.scala @@ -12,30 +12,30 @@ import scala.util.parsing.combinator.{Parsers, RegexParsers} object DescriptorParser extends TypeParser: - private val logger = LoggerFactory.getLogger(this.getClass) + private val logger = LoggerFactory.getLogger(this.getClass) - def parseMethodSignature(descriptor: String): MethodSignature = - parseSignature(methodSignature, descriptor) + def parseMethodSignature(descriptor: String): MethodSignature = + parseSignature(methodSignature, descriptor) - def parseClassSignature(descriptor: String): ClassSignature = - parseSignature(classSignature, descriptor) + def parseClassSignature(descriptor: String): ClassSignature = + parseSignature(classSignature, descriptor) - def parseFieldSignature(descriptor: String): ReferenceTypeSignature = - parseSignature(fieldSignature, descriptor) + def parseFieldSignature(descriptor: String): ReferenceTypeSignature = + parseSignature(fieldSignature, descriptor) - private def parseSignature[T](parser: Parser[T], descriptor: String): T = - parse(parser, descriptor) match - case Success(signature, _) => signature + private def parseSignature[T](parser: Parser[T], descriptor: String): T = + parse(parser, descriptor) match + case Success(signature, _) => signature - case Failure(err, _) => - logger.debug(s"parseClassSignature failed with $err") - throw new IllegalArgumentException( - s"FAILURE: Parsing invalid signature descriptor $descriptor" - ) + case Failure(err, _) => + logger.debug(s"parseClassSignature failed with $err") + throw new IllegalArgumentException( + s"FAILURE: Parsing invalid signature descriptor $descriptor" + ) - case Error(err, _) => - logger.debug(s"parseClassSignature raised error $err") - throw new IllegalArgumentException( - s"ERROR: Parsing invalid signature descriptor $descriptor" - ) + case Error(err, _) => + logger.debug(s"parseClassSignature raised error $err") + throw new IllegalArgumentException( + s"ERROR: Parsing invalid signature descriptor $descriptor" + ) end DescriptorParser diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/descriptorparser/TokenParser.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/descriptorparser/TokenParser.scala index b82240f3..41b6168b 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/descriptorparser/TokenParser.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/descriptorparser/TokenParser.scala @@ -7,54 +7,54 @@ import org.slf4j.LoggerFactory import scala.util.parsing.combinator.RegexParsers trait TokenParser extends RegexParsers: - private val logger = LoggerFactory.getLogger(classOf[TokenParser]) + private val logger = LoggerFactory.getLogger(classOf[TokenParser]) - case object Colon - case object Semicolon - case object Slash - case object ClassTypeStart - case object TypeVarStart - case object ArrayStart - case object Dot - case object OpenParen - case object CloseParen - case object OpenAngle - case object CloseAngle - case object Caret + case object Colon + case object Semicolon + case object Slash + case object ClassTypeStart + case object TypeVarStart + case object ArrayStart + case object Dot + case object OpenParen + case object CloseParen + case object OpenAngle + case object CloseAngle + case object Caret - private def translateBaseType(descriptor: String): String = - descriptor match - case "B" => Byte - case "C" => Char - case "D" => Double - case "F" => Float - case "I" => Int - case "J" => Long - case "S" => Short - case "Z" => Boolean - // Void is not a BaseType, but sort of acts like one in method return signatures. Since - // we expect valid descriptors (and types in general) as input, treating void as a base - // type simplifies the model without introducing a significant possibility for confusion or error. - case "V" => Void - case unk => - throw new IllegalArgumentException(s"$unk is not a valid base type descriptor.") + private def translateBaseType(descriptor: String): String = + descriptor match + case "B" => Byte + case "C" => Char + case "D" => Double + case "F" => Float + case "I" => Int + case "J" => Long + case "S" => Short + case "Z" => Boolean + // Void is not a BaseType, but sort of acts like one in method return signatures. Since + // we expect valid descriptors (and types in general) as input, treating void as a base + // type simplifies the model without introducing a significant possibility for confusion or error. + case "V" => Void + case unk => + throw new IllegalArgumentException(s"$unk is not a valid base type descriptor.") - def baseType: Parser[PrimitiveType] = "[BCDFIJSZ]".r ^^ { t => - PrimitiveType(translateBaseType(t)) - } - def voidDescriptor: Parser[PrimitiveType] = "V" ^^ { t => PrimitiveType(translateBaseType(t)) } - def identifier: Parser[String] = raw"[^.\[/<>:;]+".r ^^ identity - def wildcardIndicator: Parser[String] = "[+-]".r ^^ identity - def colon: Parser[Colon.type] = ":" ^^ (_ => Colon) - def semicolon: Parser[Semicolon.type] = ";" ^^ (_ => Semicolon) - def slash: Parser[Slash.type] = "/" ^^ (_ => Slash) - def classTypeStart: Parser[ClassTypeStart.type] = "L" ^^ (_ => ClassTypeStart) - def typeVarStart: Parser[TypeVarStart.type] = "T" ^^ (_ => TypeVarStart) - def arrayStart: Parser[ArrayStart.type] = "[" ^^ (_ => ArrayStart) - def dot: Parser[Dot.type] = "." ^^ (_ => Dot) - def openParen: Parser[OpenParen.type] = "(" ^^ (_ => OpenParen) - def closeParen: Parser[CloseParen.type] = ")" ^^ (_ => CloseParen) - def openAngle: Parser[OpenAngle.type] = "<" ^^ (_ => OpenAngle) - def closeAngle: Parser[CloseAngle.type] = ">" ^^ (_ => CloseAngle) - def caret: Parser[Caret.type] = "^" ^^ (_ => Caret) + def baseType: Parser[PrimitiveType] = "[BCDFIJSZ]".r ^^ { t => + PrimitiveType(translateBaseType(t)) + } + def voidDescriptor: Parser[PrimitiveType] = "V" ^^ { t => PrimitiveType(translateBaseType(t)) } + def identifier: Parser[String] = raw"[^.\[/<>:;]+".r ^^ identity + def wildcardIndicator: Parser[String] = "[+-]".r ^^ identity + def colon: Parser[Colon.type] = ":" ^^ (_ => Colon) + def semicolon: Parser[Semicolon.type] = ";" ^^ (_ => Semicolon) + def slash: Parser[Slash.type] = "/" ^^ (_ => Slash) + def classTypeStart: Parser[ClassTypeStart.type] = "L" ^^ (_ => ClassTypeStart) + def typeVarStart: Parser[TypeVarStart.type] = "T" ^^ (_ => TypeVarStart) + def arrayStart: Parser[ArrayStart.type] = "[" ^^ (_ => ArrayStart) + def dot: Parser[Dot.type] = "." ^^ (_ => Dot) + def openParen: Parser[OpenParen.type] = "(" ^^ (_ => OpenParen) + def closeParen: Parser[CloseParen.type] = ")" ^^ (_ => CloseParen) + def openAngle: Parser[OpenAngle.type] = "<" ^^ (_ => OpenAngle) + def closeAngle: Parser[CloseAngle.type] = ">" ^^ (_ => CloseAngle) + def caret: Parser[Caret.type] = "^" ^^ (_ => Caret) end TokenParser diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/descriptorparser/TypeParser.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/descriptorparser/TypeParser.scala index ac6826f9..e0eedf72 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/descriptorparser/TypeParser.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/descriptorparser/TypeParser.scala @@ -22,99 +22,99 @@ import org.slf4j.LoggerFactory trait TypeParser extends TokenParser: - private val logger = LoggerFactory.getLogger(this.getClass) + private val logger = LoggerFactory.getLogger(this.getClass) - def unboundWildcard: Parser[UnboundWildcard.type] = "*" ^^ (_ => UnboundWildcard) + def unboundWildcard: Parser[UnboundWildcard.type] = "*" ^^ (_ => UnboundWildcard) - def typeVarSignature: Parser[TypeVariableSignature] = - (typeVarStart ~ identifier ~ semicolon) ^^ { case _ ~ name ~ _ => - TypeVariableSignature(name) - } + def typeVarSignature: Parser[TypeVariableSignature] = + (typeVarStart ~ identifier ~ semicolon) ^^ { case _ ~ name ~ _ => + TypeVariableSignature(name) + } - def packageSpecifier: Parser[String] = - // TODO Expect these to be short, but potential room for optimizing string concatenation if performance is an issue. - (identifier ~ slash ~ rep(packageSpecifier)) ^^ { - case firstId ~ _ ~ Nil => firstId - case firstId ~ _ ~ otherSpecifiers => firstId + "." + otherSpecifiers.mkString - } + def packageSpecifier: Parser[String] = + // TODO Expect these to be short, but potential room for optimizing string concatenation if performance is an issue. + (identifier ~ slash ~ rep(packageSpecifier)) ^^ { + case firstId ~ _ ~ Nil => firstId + case firstId ~ _ ~ otherSpecifiers => firstId + "." + otherSpecifiers.mkString + } - def typeArguments: Parser[List[TypeArgument]] = - (openAngle ~ rep1(typeArgument) ~ closeAngle) ^^ { case _ ~ typeArgs ~ _ => typeArgs } + def typeArguments: Parser[List[TypeArgument]] = + (openAngle ~ rep1(typeArgument) ~ closeAngle) ^^ { case _ ~ typeArgs ~ _ => typeArgs } - def typeArgument: Parser[TypeArgument] = - val maybeBoundTypeArgument = (opt(wildcardIndicator) ~ referenceTypeSignature) ^^ { - case Some("-") ~ typeSignature => BoundWildcard(BoundBelow, typeSignature) + def typeArgument: Parser[TypeArgument] = + val maybeBoundTypeArgument = (opt(wildcardIndicator) ~ referenceTypeSignature) ^^ { + case Some("-") ~ typeSignature => BoundWildcard(BoundBelow, typeSignature) - case Some("+") ~ typeSignature => model.BoundWildcard(BoundAbove, typeSignature) + case Some("+") ~ typeSignature => model.BoundWildcard(BoundAbove, typeSignature) - case Some(symbol) ~ typeSignature => - logger.debug(s"Invalid wildcard indicator `$symbol`. Treating as unbound wildcard") - UnboundWildcard + case Some(symbol) ~ typeSignature => + logger.debug(s"Invalid wildcard indicator `$symbol`. Treating as unbound wildcard") + UnboundWildcard - case None ~ typeSignature => SimpleTypeArgument(typeSignature) - } + case None ~ typeSignature => SimpleTypeArgument(typeSignature) + } - maybeBoundTypeArgument | unboundWildcard + maybeBoundTypeArgument | unboundWildcard - def simpleClassTypeSignature: Parser[NameWithTypeArgs] = - (identifier ~ opt(typeArguments)) ^^ { case name ~ maybeTypes => - model.NameWithTypeArgs(name, maybeTypes.getOrElse(Nil)) - } + def simpleClassTypeSignature: Parser[NameWithTypeArgs] = + (identifier ~ opt(typeArguments)) ^^ { case name ~ maybeTypes => + model.NameWithTypeArgs(name, maybeTypes.getOrElse(Nil)) + } - def classTypeSignatureSuffix: Parser[NameWithTypeArgs] = - (dot ~ simpleClassTypeSignature) ^^ { case _ ~ sig => sig } + def classTypeSignatureSuffix: Parser[NameWithTypeArgs] = + (dot ~ simpleClassTypeSignature) ^^ { case _ ~ sig => sig } - def classTypeSignature: Parser[ClassTypeSignature] = - (classTypeStart ~ opt(packageSpecifier) ~ simpleClassTypeSignature ~ rep( - classTypeSignatureSuffix - ) ~ semicolon) ^^ { - case _ ~ packageSpecifier ~ signature ~ suffix ~ _ => - ClassTypeSignature(packageSpecifier, signature, suffix) - } + def classTypeSignature: Parser[ClassTypeSignature] = + (classTypeStart ~ opt(packageSpecifier) ~ simpleClassTypeSignature ~ rep( + classTypeSignatureSuffix + ) ~ semicolon) ^^ { + case _ ~ packageSpecifier ~ signature ~ suffix ~ _ => + ClassTypeSignature(packageSpecifier, signature, suffix) + } - def arrayTypeSignature: Parser[ArrayTypeSignature] = - (arrayStart ~ javaTypeSignature) ^^ { case _ ~ sig => ArrayTypeSignature(sig) } + def arrayTypeSignature: Parser[ArrayTypeSignature] = + (arrayStart ~ javaTypeSignature) ^^ { case _ ~ sig => ArrayTypeSignature(sig) } - def javaTypeSignature: Parser[JavaTypeSignature] = referenceTypeSignature | baseType + def javaTypeSignature: Parser[JavaTypeSignature] = referenceTypeSignature | baseType - def returnType: Parser[JavaTypeSignature] = javaTypeSignature | voidDescriptor + def returnType: Parser[JavaTypeSignature] = javaTypeSignature | voidDescriptor - def referenceTypeSignature: Parser[ReferenceTypeSignature] = - classTypeSignature | typeVarSignature | arrayTypeSignature + def referenceTypeSignature: Parser[ReferenceTypeSignature] = + classTypeSignature | typeVarSignature | arrayTypeSignature - def classBound: Parser[Option[ReferenceTypeSignature]] = - (colon ~ opt(referenceTypeSignature)) ^^ { case _ ~ maybeSignature => maybeSignature } + def classBound: Parser[Option[ReferenceTypeSignature]] = + (colon ~ opt(referenceTypeSignature)) ^^ { case _ ~ maybeSignature => maybeSignature } - def interfaceBound: Parser[ReferenceTypeSignature] = - (colon ~ referenceTypeSignature) ^^ { case _ ~ signature => signature } + def interfaceBound: Parser[ReferenceTypeSignature] = + (colon ~ referenceTypeSignature) ^^ { case _ ~ signature => signature } - def typeParameter: Parser[TypeParameter] = - (identifier ~ classBound ~ rep(interfaceBound)) ^^ { - case name ~ classBound ~ interfaceBounds => - TypeParameter(name, classBound, interfaceBounds) - } + def typeParameter: Parser[TypeParameter] = + (identifier ~ classBound ~ rep(interfaceBound)) ^^ { + case name ~ classBound ~ interfaceBounds => + TypeParameter(name, classBound, interfaceBounds) + } - def typeParameters: Parser[List[TypeParameter]] = - (openAngle ~ rep1(typeParameter) ~ closeAngle) ^^ { case _ ~ params ~ _ => params } + def typeParameters: Parser[List[TypeParameter]] = + (openAngle ~ rep1(typeParameter) ~ closeAngle) ^^ { case _ ~ params ~ _ => params } - def classSignature: Parser[ClassSignature] = - (opt(typeParameters) ~ classTypeSignature ~ rep(classTypeSignature)) ^^ { - case typeParams ~ supClass ~ supInterfaces => - ClassSignature(typeParams.getOrElse(Nil), Some(supClass), supInterfaces) - } + def classSignature: Parser[ClassSignature] = + (opt(typeParameters) ~ classTypeSignature ~ rep(classTypeSignature)) ^^ { + case typeParams ~ supClass ~ supInterfaces => + ClassSignature(typeParams.getOrElse(Nil), Some(supClass), supInterfaces) + } - def methodSignature: Parser[MethodSignature] = - (opt(typeParameters) ~ openParen ~ rep(javaTypeSignature) ~ closeParen ~ returnType ~ rep( - javaTypeSignature - )) ^^ { - case typeParams ~ _ ~ paramTypes ~ _ ~ returnSignature ~ throwsSignatures => - MethodSignature( - typeParams.getOrElse(Nil), - paramTypes, - returnSignature, - throwsSignatures - ) - } + def methodSignature: Parser[MethodSignature] = + (opt(typeParameters) ~ openParen ~ rep(javaTypeSignature) ~ closeParen ~ returnType ~ rep( + javaTypeSignature + )) ^^ { + case typeParams ~ _ ~ paramTypes ~ _ ~ returnSignature ~ throwsSignatures => + MethodSignature( + typeParams.getOrElse(Nil), + paramTypes, + returnSignature, + throwsSignatures + ) + } - def fieldSignature: Parser[ReferenceTypeSignature] = referenceTypeSignature + def fieldSignature: Parser[ReferenceTypeSignature] = referenceTypeSignature end TypeParser diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/model/Model.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/model/Model.scala index 093ab18c..59f52784 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/model/Model.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jartypereader/model/Model.scala @@ -3,24 +3,24 @@ package io.appthreat.javasrc2cpg.jartypereader.model import Bound.Bound sealed trait Named: - def name: String - def qualifiedName: String + def name: String + def qualifiedName: String - override def equals(obj: Any): Boolean = - obj match - case other: Named => - this.name == other.name && this.qualifiedName == other.qualifiedName - case _ => false + override def equals(obj: Any): Boolean = + obj match + case other: Named => + this.name == other.name && this.qualifiedName == other.qualifiedName + case _ => false - def buildQualifiedClassName(name: String, packageSpecifier: Option[String]): String = - packageSpecifier.map(ps => s"$ps.$name").getOrElse(name) + def buildQualifiedClassName(name: String, packageSpecifier: Option[String]): String = + packageSpecifier.map(ps => s"$ps.$name").getOrElse(name) case class NameWithTypeArgs(name: String, typeArguments: List[TypeArgument]) object Bound: - sealed trait Bound - case object BoundAbove extends Bound - case object BoundBelow extends Bound + sealed trait Bound + case object BoundAbove extends Bound + case object BoundBelow extends Bound sealed trait TypeArgument case class SimpleTypeArgument(typeSignature: ReferenceTypeSignature) extends TypeArgument @@ -29,8 +29,8 @@ case object UnboundWildcard ex sealed trait JavaTypeSignature extends Named case class PrimitiveType(fullName: String) extends JavaTypeSignature: - override val name: String = fullName - override val qualifiedName: String = fullName + override val name: String = fullName + override val qualifiedName: String = fullName sealed trait ReferenceTypeSignature extends JavaTypeSignature @@ -39,16 +39,16 @@ case class ClassTypeSignature( typedName: NameWithTypeArgs, suffix: List[NameWithTypeArgs] ) extends ReferenceTypeSignature: - override val name: String = (typedName :: suffix).map(_.name).mkString(".") - override val qualifiedName: String = buildQualifiedClassName(name, packageSpecifier) + override val name: String = (typedName :: suffix).map(_.name).mkString(".") + override val qualifiedName: String = buildQualifiedClassName(name, packageSpecifier) case class TypeVariableSignature(identifier: String) extends ReferenceTypeSignature: - override val name: String = identifier - override val qualifiedName: String = identifier + override val name: String = identifier + override val qualifiedName: String = identifier case class ArrayTypeSignature(signature: JavaTypeSignature) extends ReferenceTypeSignature: - private def buildArrayName(baseName: String): String = s"[$baseName]" - override val name: String = buildArrayName(signature.name) - override val qualifiedName: String = buildArrayName(signature.qualifiedName) + private def buildArrayName(baseName: String): String = s"[$baseName]" + override val name: String = buildArrayName(signature.name) + override val qualifiedName: String = buildArrayName(signature.qualifiedName) case class TypeParameter( name: String, @@ -70,8 +70,8 @@ case class MethodSignature( sealed trait ResolvedType extends Named object Unresolved extends ResolvedType: - override val name: String = "Unresolved" - override val qualifiedName: String = name + override val name: String = "Unresolved" + override val qualifiedName: String = name class ResolvedTypeDecl( override val name: String, @@ -82,33 +82,33 @@ class ResolvedTypeDecl( val fields: List[ResolvedVariableType], initDeclaredMethods: List[ResolvedMethod] ) extends ResolvedType: - override val qualifiedName: String = buildQualifiedClassName(name, packageSpecifier) + override val qualifiedName: String = buildQualifiedClassName(name, packageSpecifier) - private var declaredMethods = initDeclaredMethods + private var declaredMethods = initDeclaredMethods - def getDeclaredMethods: List[ResolvedMethod] = declaredMethods + def getDeclaredMethods: List[ResolvedMethod] = declaredMethods - private[jartypereader] def addMethods(newMethods: List[ResolvedMethod]): Unit = - declaredMethods ++= newMethods + private[jartypereader] def addMethods(newMethods: List[ResolvedMethod]): Unit = + declaredMethods ++= newMethods object ResolvedTypeDecl: - def apply( - name: String, - packageSpecifier: Option[String], - signature: ClassSignature, - isInterface: Boolean, - isAbstract: Boolean, - fields: List[ResolvedVariableType], - methods: List[ResolvedMethod] = Nil - ): ResolvedTypeDecl = - new ResolvedTypeDecl( - name, - packageSpecifier, - signature, - isInterface, - isAbstract, - fields, - methods - ) + def apply( + name: String, + packageSpecifier: Option[String], + signature: ClassSignature, + isInterface: Boolean, + isAbstract: Boolean, + fields: List[ResolvedVariableType], + methods: List[ResolvedMethod] = Nil + ): ResolvedTypeDecl = + new ResolvedTypeDecl( + name, + packageSpecifier, + signature, + isInterface, + isAbstract, + fields, + methods + ) case class ResolvedMethod( override val name: String, @@ -116,26 +116,26 @@ case class ResolvedMethod( signature: MethodSignature, isAbstract: Boolean ) extends ResolvedType: - override val qualifiedName: String = s"${parentTypeDecl.qualifiedName}.$name" + override val qualifiedName: String = s"${parentTypeDecl.qualifiedName}.$name" case class ResolvedVariableType(name: String, signature: ReferenceTypeSignature) extends ResolvedType: - override val qualifiedName: String = name + override val qualifiedName: String = name object Model: - // TODO: This is a duplicate of the TypeConstants object in TypeInfoCalculator. Remove the other one once - // we switch to the new solver. - object TypeConstants: - val Byte: String = "byte" - val Short: String = "short" - val Int: String = "int" - val Long: String = "long" - val Float: String = "float" - val Double: String = "double" - val Char: String = "char" - val Boolean: String = "boolean" - val Object: String = "java.lang.Object" - val Class: String = "java.lang.Class" - val Iterator: String = "java.util.Iterator" - val Void: String = "void" - val Any: String = "ANY" + // TODO: This is a duplicate of the TypeConstants object in TypeInfoCalculator. Remove the other one once + // we switch to the new solver. + object TypeConstants: + val Byte: String = "byte" + val Short: String = "short" + val Int: String = "int" + val Long: String = "long" + val Float: String = "float" + val Double: String = "double" + val Char: String = "char" + val Boolean: String = "boolean" + val Object: String = "java.lang.Object" + val Class: String = "java.lang.Class" + val Iterator: String = "java.util.Iterator" + val Void: String = "void" + val Any: String = "ANY" diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jpastprinter/JavaParserAstPrinter.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jpastprinter/JavaParserAstPrinter.scala index ff0e5df2..7ca7a00b 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jpastprinter/JavaParserAstPrinter.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/jpastprinter/JavaParserAstPrinter.scala @@ -9,15 +9,15 @@ import io.shiftleft.semanticcpg.language.dotextension.Shared import java.nio.file.Path object JavaParserAstPrinter: - def printJpAsts(config: Config): Unit = + def printJpAsts(config: Config): Unit = - val sourceParser = util.SourceParser(config, false) - val printer = new YamlPrinter(true) + val sourceParser = util.SourceParser(config, false) + val printer = new YamlPrinter(true) - SourceParser.getSourceFilenames(config).foreach { filename => - val relativeFilename = Path.of(config.inputPath).relativize(Path.of(filename)).toString - sourceParser.parseAnalysisFile(relativeFilename).foreach { compilationUnit => - println(relativeFilename) - println(printer.output(compilationUnit)) - } - } + SourceParser.getSourceFilenames(config).foreach { filename => + val relativeFilename = Path.of(config.inputPath).relativize(Path.of(filename)).toString + sourceParser.parseAnalysisFile(relativeFilename).foreach { compilationUnit => + println(relativeFilename) + println(printer.output(compilationUnit)) + } + } diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/AstCreationPass.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/AstCreationPass.scala index 036cf720..ba66a01e 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/AstCreationPass.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/AstCreationPass.scala @@ -26,120 +26,120 @@ import scala.util.{Success, Try} class AstCreationPass(config: Config, cpg: Cpg, sourcesOverride: Option[List[String]] = None) extends ConcurrentWriterCpgPass[String](cpg): - val global: Global = new Global() - - private val sourceFilenames = SourceParser.getSourceFilenames(config, sourcesOverride) - - val (sourceParser, symbolSolver, packagesJarMappings) = - initParserAndUtils(config, sourceFilenames) - - override def generateParts(): Array[String] = sourceFilenames - - override def runOnPart(diffGraph: DiffGraphBuilder, filename: String): Unit = - val relativeFilename = Path.of(config.inputPath).relativize(Path.of(filename)).toString - sourceParser.parseAnalysisFile(relativeFilename) match - case Some(compilationUnit) => - symbolSolver.inject(compilationUnit) - diffGraph.absorb( - new AstCreator( - relativeFilename, - compilationUnit, - global, - symbolSolver, - packagesJarMappings - )( - config.schemaValidation - ).createAst() - ) - - case None => - - /** Clear JavaParser caches. Should only be invoked after we no longer need JavaParser, e.g. as - * soon as we've built the AST layer for all files. - */ - def clearJavaParserCaches(): Unit = - JavaParserFacade.clearInstances() - - private def initParserAndUtils( - config: Config, - sourceFilenames: Array[String] - ): (SourceParser, JavaSymbolSolver, mutable.Map[String, mutable.Set[String]]) = - val dependencies = getDependencyList(config.inputPath) - val sourceParser = util.SourceParser(config, dependencies.exists(_.contains("lombok"))) - val (symbolSolver, packagesJarMappings) = - createSymbolSolver(config, dependencies, sourceParser, sourceFilenames) - (sourceParser, symbolSolver, packagesJarMappings) - - private def getDependencyList(inputPath: String): List[String] = - if config.fetchDependencies then - DependencyResolver.getDependencies(Paths.get(inputPath)) match - case Some(deps) => deps.toList - case None => - List() - else - List() - - private def createSymbolSolver( - config: Config, - dependencies: List[String], - sourceParser: SourceParser, - sourceFilenames: Array[String] - ): (JavaSymbolSolver, mutable.Map[String, mutable.Set[String]]) = - val combinedTypeSolver = new SimpleCombinedTypeSolver() - val symbolSolver = new JavaSymbolSolver(combinedTypeSolver) - - val jdkPathFromEnvVar = Option(System.getenv(JavaSrcEnvVar.JdkPath.name)) - val jdkPath = (config.jdkPath, jdkPathFromEnvVar) match - case (None, None) => - val javaHome = System.getProperty("java.home") - if javaHome != null && javaHome.nonEmpty then javaHome - else System.getenv("JAVA_HOME") - case (None, Some(jdkPath)) => - jdkPath - - case (Some(jdkPath), _) => - jdkPath - var jdkJarTypeSolver: JdkJarTypeSolver = null - // native-image could have empty JAVA_HOME - if jdkPath != null && jdkPath.nonEmpty then - jdkJarTypeSolver = JdkJarTypeSolver.fromJdkPath(jdkPath) - combinedTypeSolver.addNonCachingTypeSolver(jdkJarTypeSolver) - val relativeSourceFilenames = - sourceFilenames.map(filename => - Path.of(config.inputPath).relativize(Path.of(filename)).toString - ) - - val sourceTypeSolver = - EagerSourceTypeSolver( - relativeSourceFilenames, - sourceParser, - combinedTypeSolver, - symbolSolver - ) - combinedTypeSolver.addCachingTypeSolver(sourceTypeSolver) - combinedTypeSolver.addNonCachingTypeSolver(new ReflectionTypeSolver()) - // Add solvers for inference jars - val jarsList = config.inferenceJarPaths.flatMap(recursiveJarsFromPath).toList - (jarsList ++ dependencies) - .flatMap { path => - Try(new JarTypeSolver(path)).toOption - } - .foreach { combinedTypeSolver.addNonCachingTypeSolver(_) } - if jdkJarTypeSolver != null then (symbolSolver, jdkJarTypeSolver.packagesJarMappings) - else (symbolSolver, null) - end createSymbolSolver - - private def recursiveJarsFromPath(path: String): List[String] = - Try(File(path)) match - case Success(file) if file.isDirectory => - file.listRecursively - .map(_.canonicalPath) - .filter(_.endsWith(".jar")) - .toList - - case Success(file) if file.canonicalPath.endsWith(".jar") => - List(file.canonicalPath) - - case _ => - Nil + val global: Global = new Global() + + private val sourceFilenames = SourceParser.getSourceFilenames(config, sourcesOverride) + + val (sourceParser, symbolSolver, packagesJarMappings) = + initParserAndUtils(config, sourceFilenames) + + override def generateParts(): Array[String] = sourceFilenames + + override def runOnPart(diffGraph: DiffGraphBuilder, filename: String): Unit = + val relativeFilename = Path.of(config.inputPath).relativize(Path.of(filename)).toString + sourceParser.parseAnalysisFile(relativeFilename) match + case Some(compilationUnit) => + symbolSolver.inject(compilationUnit) + diffGraph.absorb( + new AstCreator( + relativeFilename, + compilationUnit, + global, + symbolSolver, + packagesJarMappings + )( + config.schemaValidation + ).createAst() + ) + + case None => + + /** Clear JavaParser caches. Should only be invoked after we no longer need JavaParser, e.g. as + * soon as we've built the AST layer for all files. + */ + def clearJavaParserCaches(): Unit = + JavaParserFacade.clearInstances() + + private def initParserAndUtils( + config: Config, + sourceFilenames: Array[String] + ): (SourceParser, JavaSymbolSolver, mutable.Map[String, mutable.Set[String]]) = + val dependencies = getDependencyList(config.inputPath) + val sourceParser = util.SourceParser(config, dependencies.exists(_.contains("lombok"))) + val (symbolSolver, packagesJarMappings) = + createSymbolSolver(config, dependencies, sourceParser, sourceFilenames) + (sourceParser, symbolSolver, packagesJarMappings) + + private def getDependencyList(inputPath: String): List[String] = + if config.fetchDependencies then + DependencyResolver.getDependencies(Paths.get(inputPath)) match + case Some(deps) => deps.toList + case None => + List() + else + List() + + private def createSymbolSolver( + config: Config, + dependencies: List[String], + sourceParser: SourceParser, + sourceFilenames: Array[String] + ): (JavaSymbolSolver, mutable.Map[String, mutable.Set[String]]) = + val combinedTypeSolver = new SimpleCombinedTypeSolver() + val symbolSolver = new JavaSymbolSolver(combinedTypeSolver) + + val jdkPathFromEnvVar = Option(System.getenv(JavaSrcEnvVar.JdkPath.name)) + val jdkPath = (config.jdkPath, jdkPathFromEnvVar) match + case (None, None) => + val javaHome = System.getProperty("java.home") + if javaHome != null && javaHome.nonEmpty then javaHome + else System.getenv("JAVA_HOME") + case (None, Some(jdkPath)) => + jdkPath + + case (Some(jdkPath), _) => + jdkPath + var jdkJarTypeSolver: JdkJarTypeSolver = null + // native-image could have empty JAVA_HOME + if jdkPath != null && jdkPath.nonEmpty then + jdkJarTypeSolver = JdkJarTypeSolver.fromJdkPath(jdkPath) + combinedTypeSolver.addNonCachingTypeSolver(jdkJarTypeSolver) + val relativeSourceFilenames = + sourceFilenames.map(filename => + Path.of(config.inputPath).relativize(Path.of(filename)).toString + ) + + val sourceTypeSolver = + EagerSourceTypeSolver( + relativeSourceFilenames, + sourceParser, + combinedTypeSolver, + symbolSolver + ) + combinedTypeSolver.addCachingTypeSolver(sourceTypeSolver) + combinedTypeSolver.addNonCachingTypeSolver(new ReflectionTypeSolver()) + // Add solvers for inference jars + val jarsList = config.inferenceJarPaths.flatMap(recursiveJarsFromPath).toList + (jarsList ++ dependencies) + .flatMap { path => + Try(new JarTypeSolver(path)).toOption + } + .foreach { combinedTypeSolver.addNonCachingTypeSolver(_) } + if jdkJarTypeSolver != null then (symbolSolver, jdkJarTypeSolver.packagesJarMappings) + else (symbolSolver, null) + end createSymbolSolver + + private def recursiveJarsFromPath(path: String): List[String] = + Try(File(path)) match + case Success(file) if file.isDirectory => + file.listRecursively + .map(_.canonicalPath) + .filter(_.endsWith(".jar")) + .toList + + case Success(file) if file.canonicalPath.endsWith(".jar") => + List(file.canonicalPath) + + case _ => + Nil end AstCreationPass diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/AstCreator.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/AstCreator.scala index 6da8d023..8ae8d4ec 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/AstCreator.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/AstCreator.scala @@ -204,18 +204,18 @@ case class PartialConstructor(initNode: NewCall, initArgs: Seq[Ast], blockAst: A case class ExpectedType(fullName: Option[String], resolvedType: Option[ResolvedType] = None) object ExpectedType: - def empty: ExpectedType = ExpectedType(None, None) - val Int: ExpectedType = ExpectedType(Some(TypeConstants.Int)) - val Boolean: ExpectedType = ExpectedType(Some(TypeConstants.Boolean)) - val Void: ExpectedType = ExpectedType(Some(TypeConstants.Void)) + def empty: ExpectedType = ExpectedType(None, None) + val Int: ExpectedType = ExpectedType(Some(TypeConstants.Int)) + val Boolean: ExpectedType = ExpectedType(Some(TypeConstants.Boolean)) + val Void: ExpectedType = ExpectedType(Some(TypeConstants.Void)) case class AstWithStaticInit(ast: Seq[Ast], staticInits: Seq[Ast]) object AstWithStaticInit: - val empty: AstWithStaticInit = AstWithStaticInit(Seq.empty, Seq.empty) + val empty: AstWithStaticInit = AstWithStaticInit(Seq.empty, Seq.empty) - def apply(ast: Ast): AstWithStaticInit = - AstWithStaticInit(Seq(ast), staticInits = Seq.empty) + def apply(ast: Ast): AstWithStaticInit = + AstWithStaticInit(Seq(ast), staticInits = Seq.empty) /** Translate a Java Parser AST into a CPG AST */ @@ -229,3196 +229,3199 @@ class AstCreator( extends AstCreatorBase(filename) with AstNodeBuilder[Node, AstCreator]: - private val logger = LoggerFactory.getLogger(this.getClass) - - private val scope = Scope() - - private val typeInfoCalc: TypeInfoCalculator = TypeInfoCalculator(global, symbolSolver) - private val partialConstructorQueue: mutable.ArrayBuffer[PartialConstructor] = - mutable.ArrayBuffer.empty - private val bindingTableCache = mutable.HashMap.empty[String, BindingTable] - - // TODO: Perhaps move this to a NameProvider or some such? Look at kt2cpg to see if some unified representation - // makes sense. - private val LambdaNamePrefix = "lambda$" - private val lambdaKeyPool = new IntervalKeyPool(first = 0, last = Long.MaxValue) - private val IndexNamePrefix = "$idx" - private val indexKeyPool = new IntervalKeyPool(first = 0, last = Long.MaxValue) - private val IterableNamePrefix = "$iterLocal" - private val iterableKeyPool = new IntervalKeyPool(first = 0, last = Long.MaxValue) - - /** Entry point of AST creation. Translates a compilation unit created by JavaParser into a - * DiffGraph containing the corresponding CPG AST. - */ - def createAst(): DiffGraphBuilder = - val ast = astForTranslationUnit(javaParserAst) - storeInDiffGraph(ast) - diffGraph - - /** Copy nodes/edges of given `AST` into the diff graph - */ - def storeInDiffGraph(ast: Ast): Unit = - Ast.storeInDiffGraph(ast, diffGraph) - - protected def line(node: Node): Option[Integer] = - node.getBegin.map(x => Integer.valueOf(x.line)).toScala - protected def column(node: Node): Option[Integer] = - node.getBegin.map(x => Integer.valueOf(x.column)).toScala - protected def lineEnd(node: Node): Option[Integer] = - node.getEnd.map(x => Integer.valueOf(x.line)).toScala - protected def columnEnd(node: Node): Option[Integer] = - node.getEnd.map(x => Integer.valueOf(x.line)).toScala - protected def code(node: Node): String = "" - - // TODO: Handle static imports correctly. - private def addImportsToScope(compilationUnit: CompilationUnit): Seq[NewImport] = - val (asteriskImports, specificImports) = - compilationUnit.getImports.asScala.toList.partition(_.isAsterisk) - val specificImportNodes = specificImports.map { importStmt => - val name = importStmt.getName.getIdentifier - val typeFullName = importStmt.getNameAsString // fully qualified name - typeInfoCalc.registerType(typeFullName) - val importNode = NewImport() - .importedAs(name) - .importedEntity(typeFullName) - - if importStmt.isStatic then - scope.addStaticImport(importNode) - else - scope.addType(name, typeFullName) - importNode + private val logger = LoggerFactory.getLogger(this.getClass) + + private val scope = Scope() + + private val typeInfoCalc: TypeInfoCalculator = TypeInfoCalculator(global, symbolSolver) + private val partialConstructorQueue: mutable.ArrayBuffer[PartialConstructor] = + mutable.ArrayBuffer.empty + private val bindingTableCache = mutable.HashMap.empty[String, BindingTable] + + // TODO: Perhaps move this to a NameProvider or some such? Look at kt2cpg to see if some unified representation + // makes sense. + private val LambdaNamePrefix = "lambda$" + private val lambdaKeyPool = new IntervalKeyPool(first = 0, last = Long.MaxValue) + private val IndexNamePrefix = "$idx" + private val indexKeyPool = new IntervalKeyPool(first = 0, last = Long.MaxValue) + private val IterableNamePrefix = "$iterLocal" + private val iterableKeyPool = new IntervalKeyPool(first = 0, last = Long.MaxValue) + + /** Entry point of AST creation. Translates a compilation unit created by JavaParser into a + * DiffGraph containing the corresponding CPG AST. + */ + def createAst(): DiffGraphBuilder = + val ast = astForTranslationUnit(javaParserAst) + storeInDiffGraph(ast) + diffGraph + + /** Copy nodes/edges of given `AST` into the diff graph + */ + def storeInDiffGraph(ast: Ast): Unit = + Ast.storeInDiffGraph(ast, diffGraph) + + protected def line(node: Node): Option[Integer] = + node.getBegin.map(x => Integer.valueOf(x.line)).toScala + protected def column(node: Node): Option[Integer] = + node.getBegin.map(x => Integer.valueOf(x.column)).toScala + protected def lineEnd(node: Node): Option[Integer] = + node.getEnd.map(x => Integer.valueOf(x.line)).toScala + protected def columnEnd(node: Node): Option[Integer] = + node.getEnd.map(x => Integer.valueOf(x.line)).toScala + protected def code(node: Node): String = "" + + // TODO: Handle static imports correctly. + private def addImportsToScope(compilationUnit: CompilationUnit): Seq[NewImport] = + val (asteriskImports, specificImports) = + compilationUnit.getImports.asScala.toList.partition(_.isAsterisk) + val specificImportNodes = specificImports.map { importStmt => + val name = importStmt.getName.getIdentifier + val typeFullName = importStmt.getNameAsString // fully qualified name + typeInfoCalc.registerType(typeFullName) + val importNode = NewImport() + .importedAs(name) + .importedEntity(typeFullName) + + if importStmt.isStatic then + scope.addStaticImport(importNode) + else + scope.addType(name, typeFullName) + importNode + } + + val asteriskImportNodes = asteriskImports match + case imp :: Nil => + val name = NameConstants.WildcardImportName + val typeFullName = imp.getNameAsString + val importNode = NewImport() + .importedAs(name) + .importedEntity(typeFullName) + .isWildcard(true) + scope.addWildcardImport(typeFullName) + Seq(importNode) + case _ => // Only try to guess a wildcard import if exactly one is defined + Seq.empty + specificImportNodes ++ asteriskImportNodes + end addImportsToScope + + /** Translate compilation unit into AST + */ + private def astForTranslationUnit(compilationUnit: CompilationUnit): Ast = + try + val namespaceBlock = + namespaceBlockForPackageDecl(compilationUnit.getPackageDeclaration.toScala) + + scope.pushNamespaceScope(namespaceBlock) + + val importNodes = addImportsToScope(compilationUnit).map(Ast(_)) + + val typeDeclAsts = compilationUnit.getTypes.asScala.map { typ => + astForTypeDecl( + typ, + astParentType = NodeTypes.NAMESPACE_BLOCK, + astParentFullName = namespaceBlock.fullName + ) } - val asteriskImportNodes = asteriskImports match - case imp :: Nil => - val name = NameConstants.WildcardImportName - val typeFullName = imp.getNameAsString - val importNode = NewImport() - .importedAs(name) - .importedEntity(typeFullName) - .isWildcard(true) - scope.addWildcardImport(typeFullName) - Seq(importNode) - case _ => // Only try to guess a wildcard import if exactly one is defined - Seq.empty - specificImportNodes ++ asteriskImportNodes - end addImportsToScope - - /** Translate compilation unit into AST - */ - private def astForTranslationUnit(compilationUnit: CompilationUnit): Ast = - try - val namespaceBlock = - namespaceBlockForPackageDecl(compilationUnit.getPackageDeclaration.toScala) - - scope.pushNamespaceScope(namespaceBlock) - - val importNodes = addImportsToScope(compilationUnit).map(Ast(_)) - - val typeDeclAsts = compilationUnit.getTypes.asScala.map { typ => - astForTypeDecl( - typ, - astParentType = NodeTypes.NAMESPACE_BLOCK, - astParentFullName = namespaceBlock.fullName - ) + // TODO: Add ASTs + scope.popScope() + Ast(namespaceBlock).withChildren(typeDeclAsts).withChildren(importNodes) + catch + case t: UnsolvedSymbolException => + logger.error(s"Unsolved symbol exception caught in $filename") + Ast() + case t: Throwable => + logger.debug(s"Parsing file $filename failed with $t") + logger.debug(s"Caused by ${t.getCause}") + Ast() + + /** Translate package declaration into AST consisting of a corresponding namespace block. + */ + private def namespaceBlockForPackageDecl(packageDecl: Option[PackageDeclaration]) + : NewNamespaceBlock = + packageDecl match + case Some(decl) => + val packageName = decl.getName.toString + val fullName = s"$filename:$packageName" + NewNamespaceBlock() + .name(packageName) + .fullName(fullName) + .filename(filename) + case None => + globalNamespaceBlock() + + private def tryWithSafeStackOverflow[T](expr: => T): Try[T] = + try + Try(expr) + catch + // This is really, really ugly, but there's a bug in the JavaParser symbol solver that can lead to + // unterminated recursion in some cases where types cannot be resolved. + // Update: This must be fixed with https://github.com/javaparser/javaparser/pull/4236 + case e: StackOverflowError => + logger.debug(s"Caught StackOverflowError in $filename") + Failure(e) + + private def composeSignature( + maybeReturnType: Option[String], + maybeParameterTypes: Option[List[String]], + parameterCount: Int + ): String = + (maybeReturnType, maybeParameterTypes) match + case (Some(returnType), Some(parameterTypes)) => + composeMethodLikeSignature(returnType, parameterTypes) + + case _ => + composeUnresolvedSignature(parameterCount) + + private def methodSignature( + method: ResolvedMethodDeclaration, + typeParamValues: ResolvedTypeParametersMap + ): String = + val maybeParameterTypes = calcParameterTypes(method, typeParamValues) + + val maybeReturnType = + Try(method.getReturnType).toOption + .flatMap(returnType => typeInfoCalc.fullName(returnType, typeParamValues)) + + composeSignature(maybeReturnType, maybeParameterTypes, method.getNumberOfParams) + + private def toOptionList[T](items: collection.Seq[Option[T]]): Option[List[T]] = + items.foldLeft[Option[List[T]]](Some(Nil)) { + case (Some(acc), Some(value)) => Some(acc :+ value) + case _ => None + } + + private def calcParameterTypes( + methodLike: ResolvedMethodLikeDeclaration, + typeParamValues: ResolvedTypeParametersMap + ): Option[List[String]] = + val parameterTypes = + Range(0, methodLike.getNumberOfParams) + .flatMap { index => + Try(methodLike.getParam(index)).toOption + } + .map { param => + Try(param.getType).toOption + .flatMap(paramType => typeInfoCalc.fullName(paramType, typeParamValues)) } - // TODO: Add ASTs - scope.popScope() - Ast(namespaceBlock).withChildren(typeDeclAsts).withChildren(importNodes) - catch - case t: UnsolvedSymbolException => - logger.error(s"Unsolved symbol exception caught in $filename") - Ast() - case t: Throwable => - logger.debug(s"Parsing file $filename failed with $t") - logger.debug(s"Caused by ${t.getCause}") - Ast() - - /** Translate package declaration into AST consisting of a corresponding namespace block. - */ - private def namespaceBlockForPackageDecl(packageDecl: Option[PackageDeclaration]) - : NewNamespaceBlock = - packageDecl match - case Some(decl) => - val packageName = decl.getName.toString - val fullName = s"$filename:$packageName" - NewNamespaceBlock() - .name(packageName) - .fullName(fullName) - .filename(filename) - case None => - globalNamespaceBlock() - - private def tryWithSafeStackOverflow[T](expr: => T): Try[T] = - try - Try(expr) - catch - // This is really, really ugly, but there's a bug in the JavaParser symbol solver that can lead to - // unterminated recursion in some cases where types cannot be resolved. - // Update: This must be fixed with https://github.com/javaparser/javaparser/pull/4236 - case e: StackOverflowError => - logger.debug(s"Caught StackOverflowError in $filename") - Failure(e) - - private def composeSignature( - maybeReturnType: Option[String], - maybeParameterTypes: Option[List[String]], - parameterCount: Int - ): String = - (maybeReturnType, maybeParameterTypes) match - case (Some(returnType), Some(parameterTypes)) => - composeMethodLikeSignature(returnType, parameterTypes) - - case _ => - composeUnresolvedSignature(parameterCount) - - private def methodSignature( - method: ResolvedMethodDeclaration, - typeParamValues: ResolvedTypeParametersMap - ): String = - val maybeParameterTypes = calcParameterTypes(method, typeParamValues) - - val maybeReturnType = - Try(method.getReturnType).toOption - .flatMap(returnType => typeInfoCalc.fullName(returnType, typeParamValues)) - - composeSignature(maybeReturnType, maybeParameterTypes, method.getNumberOfParams) - - private def toOptionList[T](items: collection.Seq[Option[T]]): Option[List[T]] = - items.foldLeft[Option[List[T]]](Some(Nil)) { - case (Some(acc), Some(value)) => Some(acc :+ value) - case _ => None - } - - private def calcParameterTypes( - methodLike: ResolvedMethodLikeDeclaration, - typeParamValues: ResolvedTypeParametersMap - ): Option[List[String]] = - val parameterTypes = - Range(0, methodLike.getNumberOfParams) - .flatMap { index => - Try(methodLike.getParam(index)).toOption - } - .map { param => - Try(param.getType).toOption - .flatMap(paramType => typeInfoCalc.fullName(paramType, typeParamValues)) - } - - toOptionList(parameterTypes) + toOptionList(parameterTypes) - def getBindingTable(typeDecl: ResolvedReferenceTypeDeclaration): BindingTable = - val fullName = typeInfoCalc.fullName(typeDecl).getOrElse { - val qualifiedName = typeDecl.getQualifiedName - logger.debug( - s"Could not get full name for resolved type decl $qualifiedName. THIS SHOULD NOT HAPPEN!" - ) - qualifiedName - } - bindingTableCache.getOrElseUpdate( - fullName, - createBindingTable( - fullName, - typeDecl, - getBindingTable, - new BindingTableAdapterForJavaparser(methodSignature) - ) + def getBindingTable(typeDecl: ResolvedReferenceTypeDeclaration): BindingTable = + val fullName = typeInfoCalc.fullName(typeDecl).getOrElse { + val qualifiedName = typeDecl.getQualifiedName + logger.debug( + s"Could not get full name for resolved type decl $qualifiedName. THIS SHOULD NOT HAPPEN!" ) - - private def getLambdaBindingTable(lambdaBindingInfo: LambdaBindingInfo): BindingTable = - val fullName = lambdaBindingInfo.fullName - - bindingTableCache.getOrElseUpdate( - fullName, - createBindingTable( - fullName, - lambdaBindingInfo, - getBindingTable, - new BindingTableAdapterForLambdas(methodSignature) - ) + qualifiedName + } + bindingTableCache.getOrElseUpdate( + fullName, + createBindingTable( + fullName, + typeDecl, + getBindingTable, + new BindingTableAdapterForJavaparser(methodSignature) + ) + ) + + private def getLambdaBindingTable(lambdaBindingInfo: LambdaBindingInfo): BindingTable = + val fullName = lambdaBindingInfo.fullName + + bindingTableCache.getOrElseUpdate( + fullName, + createBindingTable( + fullName, + lambdaBindingInfo, + getBindingTable, + new BindingTableAdapterForLambdas(methodSignature) + ) + ) + + private def createBindingNodes(typeDeclNode: NewTypeDecl, bindingTable: BindingTable): Unit = + // We only sort to get stable output. + val sortedEntries = + bindingTable.getEntries.toBuffer.sortBy((entry: BindingTableEntry) => + s"${entry.name}${entry.signature}" ) - private def createBindingNodes(typeDeclNode: NewTypeDecl, bindingTable: BindingTable): Unit = - // We only sort to get stable output. - val sortedEntries = - bindingTable.getEntries.toBuffer.sortBy((entry: BindingTableEntry) => - s"${entry.name}${entry.signature}" - ) - - sortedEntries.foreach { entry => - val bindingNode = - newBindingNode(entry.name, entry.signature, entry.implementingMethodFullName) - - diffGraph.addNode(bindingNode) - diffGraph.addEdge(typeDeclNode, bindingNode, EdgeTypes.BINDS) - } - - private def astForTypeDeclMember( - member: BodyDeclaration[?], - astParentFullName: String - ): AstWithStaticInit = - member match - case constructor: ConstructorDeclaration => - val ast = astForConstructor(constructor) - - AstWithStaticInit(ast) - - case method: MethodDeclaration => - val ast = astForMethod(method) - - AstWithStaticInit(ast) - - case typeDeclaration: TypeDeclaration[?] => - AstWithStaticInit(astForTypeDecl( - typeDeclaration, - NodeTypes.TYPE_DECL, - astParentFullName - )) - - case fieldDeclaration: FieldDeclaration => - val memberAsts = fieldDeclaration.getVariables.asScala.toList.map { variable => - astForFieldVariable(variable, fieldDeclaration) - } - - val assignments = assignmentsForVarDecl( - fieldDeclaration.getVariables.asScala.toList, - line(fieldDeclaration), - column(fieldDeclaration) - ) + sortedEntries.foreach { entry => + val bindingNode = + newBindingNode(entry.name, entry.signature, entry.implementingMethodFullName) - val staticInitAsts = if fieldDeclaration.isStatic then assignments else Nil - if !fieldDeclaration.isStatic then scope.addMemberInitializers(assignments) + diffGraph.addNode(bindingNode) + diffGraph.addEdge(typeDeclNode, bindingNode, EdgeTypes.BINDS) + } - AstWithStaticInit(memberAsts, staticInitAsts) + private def astForTypeDeclMember( + member: BodyDeclaration[?], + astParentFullName: String + ): AstWithStaticInit = + member match + case constructor: ConstructorDeclaration => + val ast = astForConstructor(constructor) - case initDeclaration: InitializerDeclaration => - val stmts = initDeclaration.getBody.getStatements - val asts = stmts.asScala.flatMap(astsForStatement).toList - AstWithStaticInit(ast = Seq.empty, staticInits = asts) + AstWithStaticInit(ast) - case unhandled => - // AnnotationMemberDeclarations and InitializerDeclarations as children of typeDecls are the - // expected cases. - logger.debug( - s"Found unhandled typeDecl member ${unhandled.getClass} in file $filename" - ) - AstWithStaticInit.empty + case method: MethodDeclaration => + val ast = astForMethod(method) - private def identifierForResolvedTypeParameter(typeParameter: ResolvedTypeParameterDeclaration) - : NewIdentifier = - val name = typeParameter.getName - val typeFullName = Try(typeParameter.getUpperBound).toOption - .flatMap(typeInfoCalc.fullName) - .getOrElse(TypeConstants.Object) - typeInfoCalc.registerType(typeFullName) - newIdentifierNode(name, typeFullName) - - private def clinitAstFromStaticInits(staticInits: Seq[Ast]): Option[Ast] = - Option.when(staticInits.nonEmpty) { - val signature = composeMethodLikeSignature(TypeConstants.Void, Nil) - val enclosingDeclName = - scope.enclosingTypeDeclFullName.getOrElse(Defines.UnresolvedNamespace) - val fullName = - composeMethodFullName(enclosingDeclName, Defines.StaticInitMethodName, signature) - staticInitMethodAst(staticInits.toList, fullName, Some(signature), TypeConstants.Void) - } + AstWithStaticInit(ast) - private def codeForTypeDecl(typ: TypeDeclaration[?], isInterface: Boolean): String = - val codeBuilder = new mutable.StringBuilder() - if typ.isPublic then - codeBuilder.append("public ") - else if typ.isPrivate then - codeBuilder.append("private ") - else if typ.isProtected then - codeBuilder.append("protected ") - - if typ.isStatic then - codeBuilder.append("static ") - - val classPrefix = - if isInterface then - "interface " - else if typ.isEnumDeclaration then - "enum " - else - "class " - codeBuilder.append(classPrefix) - codeBuilder.append(typ.getNameAsString) - - codeBuilder.toString() - end codeForTypeDecl - - private def modifiersForTypeDecl( - typ: TypeDeclaration[?], - isInterface: Boolean - ): List[NewModifier] = - val accessModifierType = if typ.isPublic then - Some(ModifierTypes.PUBLIC) - else if typ.isPrivate then - Some(ModifierTypes.PRIVATE) - else if typ.isProtected then - Some(ModifierTypes.PROTECTED) - else - None - val accessModifier = accessModifierType.map(newModifierNode) - - val abstractModifier = - Option.when(isInterface || typ.getMethods.asScala.exists(_.isAbstract))(newModifierNode( - ModifierTypes.ABSTRACT + case typeDeclaration: TypeDeclaration[?] => + AstWithStaticInit(astForTypeDecl( + typeDeclaration, + NodeTypes.TYPE_DECL, + astParentFullName )) - List(accessModifier, abstractModifier).flatten - end modifiersForTypeDecl - - private def createTypeDeclNode( - typ: TypeDeclaration[?], - astParentType: String, - astParentFullName: String, - isInterface: Boolean - ): NewTypeDecl = - val baseTypeFullNames = if typ.isClassOrInterfaceDeclaration then - val decl = typ.asClassOrInterfaceDeclaration() - val extendedTypes = decl.getExtendedTypes.asScala - val implementedTypes = decl.getImplementedTypes.asScala - val inheritsFromTypeNames = - (extendedTypes ++ implementedTypes).flatMap { typ => - typeInfoCalc.fullName(typ).orElse(scope.lookupType(typ.getNameAsString)) - } - val maybeJavaObjectType = if extendedTypes.isEmpty then - typeInfoCalc.registerType(TypeConstants.Object) - Seq(TypeConstants.Object) - else - Seq() - maybeJavaObjectType ++ inheritsFromTypeNames - else - List.empty[String] - - val resolvedType = tryWithSafeStackOverflow(typ.resolve()).toOption - val defaultFullName = s"${Defines.UnresolvedNamespace}.${typ.getNameAsString}" - val name = resolvedType.flatMap(typeInfoCalc.name).getOrElse(typ.getNameAsString) - val typeFullName = resolvedType.flatMap(typeInfoCalc.fullName).getOrElse(defaultFullName) - val code = codeForTypeDecl(typ, isInterface) - val typeDecl = NewTypeDecl() - .name(name) - .fullName(typeFullName) - .lineNumber(line(typ)) - .columnNumber(column(typ)) - .inheritsFromTypeFullName(baseTypeFullNames) - .filename(filename) - .code(code) - .astParentType(astParentType) - .astParentFullName(astParentFullName) - if packagesJarMappings.contains(typeFullName) then - typeDecl.aliasTypeFullName(packagesJarMappings.getOrElse( - typeFullName, - mutable.Set.empty - ).headOption) - typeDecl - end createTypeDeclNode - - private def addTypeDeclTypeParamsToScope(typ: TypeDeclaration[?]): Unit = - tryWithSafeStackOverflow(typ.resolve()).map(_.getTypeParameters.asScala) match - case Success(resolvedTypeParams) => - resolvedTypeParams - .map(identifierForResolvedTypeParameter) - .foreach { typeParamIdentifier => - scope.addType(typeParamIdentifier.name, typeParamIdentifier.typeFullName) - } - - case _ => // Nothing to do here - private def astForTypeDecl( - typ: TypeDeclaration[?], - astParentType: String, - astParentFullName: String - ): Ast = - val isInterface = typ match - case classDeclaration: ClassOrInterfaceDeclaration => classDeclaration.isInterface - case _ => false - - val typeDeclNode = createTypeDeclNode(typ, astParentType, astParentFullName, isInterface) - - scope.pushTypeDeclScope(typeDeclNode) - addTypeDeclTypeParamsToScope(typ) - - val enumEntryAsts = if typ.isEnumDeclaration then - typ.asEnumDeclaration().getEntries.asScala.map(astForEnumEntry).toList - else - List.empty - - val staticInits: mutable.Buffer[Ast] = mutable.Buffer() - val memberAsts = typ.getMembers.asScala.flatMap { member => - val astWithInits = - astForTypeDeclMember(member, astParentFullName = NodeTypes.TYPE_DECL) - staticInits.appendAll(astWithInits.staticInits) - astWithInits.ast - } - - val defaultConstructorAst = if !isInterface && typ.getConstructors.isEmpty then - Some(astForDefaultConstructor()) - else - None - - val annotationAsts = typ.getAnnotations.asScala.map(astForAnnotationExpr) - - val clinitAst = clinitAstFromStaticInits(staticInits.toSeq) - - val localDecls = scope.localDeclsInScope - val lambdaMethods = scope.lambdaMethodsInScope - - val modifiers = modifiersForTypeDecl(typ, isInterface) - - val typeDeclAst = Ast(typeDeclNode) - .withChildren(enumEntryAsts) - .withChildren(memberAsts) - .withChildren(defaultConstructorAst.toList) - .withChildren(annotationAsts) - .withChildren(clinitAst.toSeq) - .withChildren(localDecls) - .withChildren(lambdaMethods) - .withChildren(modifiers.map(Ast(_))) - - val defaultConstructorBindingEntry = - defaultConstructorAst - .flatMap(_.root) - .collect { case defaultConstructor: NewMethod => - BindingTableEntry( - io.appthreat.x2cpg.Defines.ConstructorMethodName, - defaultConstructor.signature, - defaultConstructor.fullName - ) - } - - // Annotation declarations need no binding table as objects of this - // typ never get called from user code. - // Furthermore the parser library throws an exception when trying to - // access e.g. the declared methods of an annotation declaration. - if !typ.isInstanceOf[AnnotationDeclaration] then - tryWithSafeStackOverflow(typ.resolve()).toOption.foreach { resolvedTypeDecl => - val bindingTable = getBindingTable(resolvedTypeDecl) - defaultConstructorBindingEntry.foreach(bindingTable.add) - createBindingNodes(typeDeclNode, bindingTable) + case fieldDeclaration: FieldDeclaration => + val memberAsts = fieldDeclaration.getVariables.asScala.toList.map { variable => + astForFieldVariable(variable, fieldDeclaration) } - scope.popScope() - - typeDeclAst - end astForTypeDecl - - private def astForDefaultConstructor(): Ast = - val typeFullName = scope.enclosingTypeDeclFullName - val signature = s"${TypeConstants.Void}()" - val fullName = composeMethodFullName( - typeFullName.getOrElse(Defines.UnresolvedNamespace), - Defines.ConstructorMethodName, - signature - ) - val constructorNode = NewMethod() - .name(io.appthreat.x2cpg.Defines.ConstructorMethodName) - .fullName(fullName) - .signature(signature) - .filename(filename) - .isExternal(false) - - val thisAst = Ast(thisNodeForMethod(typeFullName, lineNumber = None)) - val bodyAst = Ast(NewBlock()).withChildren(scope.memberInitializers) - - val returnNode = newMethodReturnNode(TypeConstants.Void, line = None, column = None) - - val modifiers = - List(newModifierNode(ModifierTypes.CONSTRUCTOR), newModifierNode(ModifierTypes.PUBLIC)) - - methodAstWithAnnotations(constructorNode, Seq(thisAst), bodyAst, returnNode, modifiers) - end astForDefaultConstructor - - private def astForEnumEntry(entry: EnumConstantDeclaration): Ast = - // TODO Fix enum entries in general - val typeFullName = - tryWithSafeStackOverflow(entry.resolve().getType).toOption.flatMap( - typeInfoCalc.fullName + val assignments = assignmentsForVarDecl( + fieldDeclaration.getVariables.asScala.toList, + line(fieldDeclaration), + column(fieldDeclaration) ) - val entryNode = - memberNode(entry, entry.getNameAsString, entry.toString, typeFullName.getOrElse("ANY")) - - val name = - s"${typeFullName.getOrElse(Defines.UnresolvedNamespace)}.${Defines.ConstructorMethodName}" - - Ast(entryNode) - - private def modifiersForFieldDeclaration(decl: FieldDeclaration): Seq[Ast] = - val staticModifier = - Option.when(decl.isStatic)(newModifierNode(ModifierTypes.STATIC)) - - val accessModifierType = - if decl.isPublic then - Some(ModifierTypes.PUBLIC) - else if decl.isPrivate then - Some(ModifierTypes.PRIVATE) - else if decl.isProtected then - Some(ModifierTypes.PROTECTED) - else - None + val staticInitAsts = if fieldDeclaration.isStatic then assignments else Nil + if !fieldDeclaration.isStatic then scope.addMemberInitializers(assignments) - val accessModifier = accessModifierType.map(newModifierNode) + AstWithStaticInit(memberAsts, staticInitAsts) - List(staticModifier, accessModifier).flatten.map(Ast(_)) + case initDeclaration: InitializerDeclaration => + val stmts = initDeclaration.getBody.getStatements + val asts = stmts.asScala.flatMap(astsForStatement).toList + AstWithStaticInit(ast = Seq.empty, staticInits = asts) - private def astForFieldVariable( - v: VariableDeclarator, - fieldDeclaration: FieldDeclaration - ): Ast = - // TODO: Should be able to find expected type here - val annotations = fieldDeclaration.getAnnotations + case unhandled => + // AnnotationMemberDeclarations and InitializerDeclarations as children of typeDecls are the + // expected cases. + logger.debug( + s"Found unhandled typeDecl member ${unhandled.getClass} in file $filename" + ) + AstWithStaticInit.empty + + private def identifierForResolvedTypeParameter(typeParameter: ResolvedTypeParameterDeclaration) + : NewIdentifier = + val name = typeParameter.getName + val typeFullName = Try(typeParameter.getUpperBound).toOption + .flatMap(typeInfoCalc.fullName) + .getOrElse(TypeConstants.Object) + typeInfoCalc.registerType(typeFullName) + newIdentifierNode(name, typeFullName) + + private def clinitAstFromStaticInits(staticInits: Seq[Ast]): Option[Ast] = + Option.when(staticInits.nonEmpty) { + val signature = composeMethodLikeSignature(TypeConstants.Void, Nil) + val enclosingDeclName = + scope.enclosingTypeDeclFullName.getOrElse(Defines.UnresolvedNamespace) + val fullName = + composeMethodFullName(enclosingDeclName, Defines.StaticInitMethodName, signature) + staticInitMethodAst(staticInits.toList, fullName, Some(signature), TypeConstants.Void) + } + + private def codeForTypeDecl(typ: TypeDeclaration[?], isInterface: Boolean): String = + val codeBuilder = new mutable.StringBuilder() + if typ.isPublic then + codeBuilder.append("public ") + else if typ.isPrivate then + codeBuilder.append("private ") + else if typ.isProtected then + codeBuilder.append("protected ") + + if typ.isStatic then + codeBuilder.append("static ") + + val classPrefix = + if isInterface then + "interface " + else if typ.isEnumDeclaration then + "enum " + else + "class " + codeBuilder.append(classPrefix) + codeBuilder.append(typ.getNameAsString) + + codeBuilder.toString() + end codeForTypeDecl + + private def modifiersForTypeDecl( + typ: TypeDeclaration[?], + isInterface: Boolean + ): List[NewModifier] = + val accessModifierType = if typ.isPublic then + Some(ModifierTypes.PUBLIC) + else if typ.isPrivate then + Some(ModifierTypes.PRIVATE) + else if typ.isProtected then + Some(ModifierTypes.PROTECTED) + else + None + val accessModifier = accessModifierType.map(newModifierNode) + + val abstractModifier = + Option.when(isInterface || typ.getMethods.asScala.exists(_.isAbstract))(newModifierNode( + ModifierTypes.ABSTRACT + )) - // variable can be declared with generic type, so we need to get rid of the <> part of it to get the package information - // and append the <> when forming the typeFullName again - // Ex - private Consumer consumer; - // From Consumer we need to get to Consumer so splitting it by '<' and then combining with '<' to - // form typeFullName as Consumer + List(accessModifier, abstractModifier).flatten + end modifiersForTypeDecl + + private def createTypeDeclNode( + typ: TypeDeclaration[?], + astParentType: String, + astParentFullName: String, + isInterface: Boolean + ): NewTypeDecl = + val baseTypeFullNames = if typ.isClassOrInterfaceDeclaration then + val decl = typ.asClassOrInterfaceDeclaration() + val extendedTypes = decl.getExtendedTypes.asScala + val implementedTypes = decl.getImplementedTypes.asScala + val inheritsFromTypeNames = + (extendedTypes ++ implementedTypes).flatMap { typ => + typeInfoCalc.fullName(typ).orElse(scope.lookupType(typ.getNameAsString)) + } + val maybeJavaObjectType = if extendedTypes.isEmpty then + typeInfoCalc.registerType(TypeConstants.Object) + Seq(TypeConstants.Object) + else + Seq() + maybeJavaObjectType ++ inheritsFromTypeNames + else + List.empty[String] + + val resolvedType = tryWithSafeStackOverflow(typ.resolve()).toOption + val defaultFullName = s"${Defines.UnresolvedNamespace}.${typ.getNameAsString}" + val name = resolvedType.flatMap(typeInfoCalc.name).getOrElse(typ.getNameAsString) + val typeFullName = resolvedType.flatMap(typeInfoCalc.fullName).getOrElse(defaultFullName) + val code = codeForTypeDecl(typ, isInterface) + val typeDecl = NewTypeDecl() + .name(name) + .fullName(typeFullName) + .lineNumber(line(typ)) + .columnNumber(column(typ)) + .inheritsFromTypeFullName(baseTypeFullNames) + .filename(filename) + .code(code) + .astParentType(astParentType) + .astParentFullName(astParentFullName) + if packagesJarMappings.contains(typeFullName) then + typeDecl.aliasTypeFullName(packagesJarMappings.getOrElse( + typeFullName, + mutable.Set.empty + ).headOption) + typeDecl + end createTypeDeclNode + + private def addTypeDeclTypeParamsToScope(typ: TypeDeclaration[?]): Unit = + tryWithSafeStackOverflow(typ.resolve()).map(_.getTypeParameters.asScala) match + case Success(resolvedTypeParams) => + resolvedTypeParams + .map(identifierForResolvedTypeParameter) + .foreach { typeParamIdentifier => + scope.addType(typeParamIdentifier.name, typeParamIdentifier.typeFullName) + } - val typeFullNameWithoutGenericSplit = typeInfoCalc - .fullName(v.getType) - .orElse(scope.lookupType(v.getTypeAsString)) - .getOrElse(guessTypeFullName(v.getTypeAsString)) - val typeFullName = - // Check if the typeFullName is unresolved and if it has generic information to resolve the typeFullName - if - typeFullNameWithoutGenericSplit - .contains(Defines.UnresolvedNamespace) && v.getTypeAsString.contains( - Defines.LeftAngularBracket + case _ => // Nothing to do here + private def astForTypeDecl( + typ: TypeDeclaration[?], + astParentType: String, + astParentFullName: String + ): Ast = + val isInterface = typ match + case classDeclaration: ClassOrInterfaceDeclaration => classDeclaration.isInterface + case _ => false + + val typeDeclNode = createTypeDeclNode(typ, astParentType, astParentFullName, isInterface) + + scope.pushTypeDeclScope(typeDeclNode) + addTypeDeclTypeParamsToScope(typ) + + val enumEntryAsts = if typ.isEnumDeclaration then + typ.asEnumDeclaration().getEntries.asScala.map(astForEnumEntry).toList + else + List.empty + + val staticInits: mutable.Buffer[Ast] = mutable.Buffer() + val memberAsts = typ.getMembers.asScala.flatMap { member => + val astWithInits = + astForTypeDeclMember(member, astParentFullName = NodeTypes.TYPE_DECL) + staticInits.appendAll(astWithInits.staticInits) + astWithInits.ast + } + + val defaultConstructorAst = if !isInterface && typ.getConstructors.isEmpty then + Some(astForDefaultConstructor()) + else + None + + val annotationAsts = typ.getAnnotations.asScala.map(astForAnnotationExpr) + + val clinitAst = clinitAstFromStaticInits(staticInits.toSeq) + + val localDecls = scope.localDeclsInScope + val lambdaMethods = scope.lambdaMethodsInScope + + val modifiers = modifiersForTypeDecl(typ, isInterface) + + val typeDeclAst = Ast(typeDeclNode) + .withChildren(enumEntryAsts) + .withChildren(memberAsts) + .withChildren(defaultConstructorAst.toList) + .withChildren(annotationAsts) + .withChildren(clinitAst.toSeq) + .withChildren(localDecls) + .withChildren(lambdaMethods) + .withChildren(modifiers.map(Ast(_))) + + val defaultConstructorBindingEntry = + defaultConstructorAst + .flatMap(_.root) + .collect { case defaultConstructor: NewMethod => + BindingTableEntry( + io.appthreat.x2cpg.Defines.ConstructorMethodName, + defaultConstructor.signature, + defaultConstructor.fullName ) - then - val splitByLeftAngular = v.getTypeAsString.split(Defines.LeftAngularBracket) - scope.lookupType(splitByLeftAngular.head) match - case Some(fullName) => - fullName + splitByLeftAngular - .slice(1, splitByLeftAngular.size) - .mkString(Defines.LeftAngularBracket, Defines.LeftAngularBracket, "") - case None => typeFullNameWithoutGenericSplit - else typeFullNameWithoutGenericSplit - val name = v.getName.toString - val node = memberNode(v, name, s"$typeFullName $name", typeFullName) - val memberAst = Ast(node) - val annotationAsts = annotations.asScala.map(astForAnnotationExpr) - - val fieldDeclModifiers = modifiersForFieldDeclaration(fieldDeclaration) - - scope.addMember(node, fieldDeclaration.isStatic) - - memberAst - .withChildren(annotationAsts) - .withChildren(fieldDeclModifiers) - end astForFieldVariable - - private def astForConstructor(constructorDeclaration: ConstructorDeclaration): Ast = - val constructorNode = createPartialMethod(constructorDeclaration) - .name(io.appthreat.x2cpg.Defines.ConstructorMethodName) - - scope.pushMethodScope(constructorNode, ExpectedType.Void) - val maybeResolved = tryWithSafeStackOverflow(constructorDeclaration.resolve()) - - val parameterAsts = astsForParameterList(constructorDeclaration.getParameters).toList - val paramTypes = argumentTypesForMethodLike(maybeResolved) - val signature = composeSignature(Some(TypeConstants.Void), paramTypes, parameterAsts.size) - val typeFullName = scope.enclosingTypeDeclFullName - val fullName = - composeMethodFullName( - typeFullName.getOrElse(Defines.UnresolvedNamespace), - Defines.ConstructorMethodName, - signature - ) - val typeNameLookup = fullName.takeWhile(_ != ':').split("\\.").dropRight(1).mkString(".") - constructorNode - .fullName(fullName) - .signature(signature) - if packagesJarMappings.contains(typeNameLookup) then - constructorNode.astParentType(packagesJarMappings.getOrElse( - typeNameLookup, - mutable.Set.empty - ).head) - - parameterAsts.foreach { ast => - ast.root match - case Some(parameter: NewMethodParameterIn) => scope.addParameter(parameter) - case _ => // This should never happen - } + } - val thisNode = thisNodeForMethod(typeFullName, line(constructorDeclaration)) - scope.addParameter(thisNode) - val thisAst = Ast(thisNode) + // Annotation declarations need no binding table as objects of this + // typ never get called from user code. + // Furthermore the parser library throws an exception when trying to + // access e.g. the declared methods of an annotation declaration. + if !typ.isInstanceOf[AnnotationDeclaration] then + tryWithSafeStackOverflow(typ.resolve()).toOption.foreach { resolvedTypeDecl => + val bindingTable = getBindingTable(resolvedTypeDecl) + defaultConstructorBindingEntry.foreach(bindingTable.add) + createBindingNodes(typeDeclNode, bindingTable) + } + + scope.popScope() + + typeDeclAst + end astForTypeDecl + + private def astForDefaultConstructor(): Ast = + val typeFullName = scope.enclosingTypeDeclFullName + val signature = s"${TypeConstants.Void}()" + val fullName = composeMethodFullName( + typeFullName.getOrElse(Defines.UnresolvedNamespace), + Defines.ConstructorMethodName, + signature + ) + val constructorNode = NewMethod() + .name(io.appthreat.x2cpg.Defines.ConstructorMethodName) + .fullName(fullName) + .signature(signature) + .filename(filename) + .isExternal(false) + + val thisAst = Ast(thisNodeForMethod(typeFullName, lineNumber = None)) + val bodyAst = Ast(NewBlock()).withChildren(scope.memberInitializers) + + val returnNode = newMethodReturnNode(TypeConstants.Void, line = None, column = None) + + val modifiers = + List(newModifierNode(ModifierTypes.CONSTRUCTOR), newModifierNode(ModifierTypes.PUBLIC)) + + methodAstWithAnnotations(constructorNode, Seq(thisAst), bodyAst, returnNode, modifiers) + end astForDefaultConstructor + + private def astForEnumEntry(entry: EnumConstantDeclaration): Ast = + // TODO Fix enum entries in general + val typeFullName = + tryWithSafeStackOverflow(entry.resolve().getType).toOption.flatMap( + typeInfoCalc.fullName + ) - val bodyAst = astForConstructorBody(Some(constructorDeclaration.getBody)) - val methodReturn = constructorReturnNode(constructorDeclaration) + val entryNode = + memberNode(entry, entry.getNameAsString, entry.toString, typeFullName.getOrElse("ANY")) - val annotationAsts = - constructorDeclaration.getAnnotations.asScala.map(astForAnnotationExpr).toList + val name = + s"${typeFullName.getOrElse(Defines.UnresolvedNamespace)}.${Defines.ConstructorMethodName}" - val modifiers = - NewModifier().modifierType(ModifierTypes.CONSTRUCTOR) :: modifiersForMethod( - constructorDeclaration - ).filterNot( - _.modifierType == ModifierTypes.VIRTUAL - ) + Ast(entryNode) - scope.popScope() + private def modifiersForFieldDeclaration(decl: FieldDeclaration): Seq[Ast] = + val staticModifier = + Option.when(decl.isStatic)(newModifierNode(ModifierTypes.STATIC)) - methodAstWithAnnotations( - constructorNode, - thisAst :: parameterAsts, - bodyAst, - methodReturn, - modifiers, - annotationAsts + val accessModifierType = + if decl.isPublic then + Some(ModifierTypes.PUBLIC) + else if decl.isPrivate then + Some(ModifierTypes.PRIVATE) + else if decl.isProtected then + Some(ModifierTypes.PROTECTED) + else + None + + val accessModifier = accessModifierType.map(newModifierNode) + + List(staticModifier, accessModifier).flatten.map(Ast(_)) + + private def astForFieldVariable( + v: VariableDeclarator, + fieldDeclaration: FieldDeclaration + ): Ast = + // TODO: Should be able to find expected type here + val annotations = fieldDeclaration.getAnnotations + + // variable can be declared with generic type, so we need to get rid of the <> part of it to get the package information + // and append the <> when forming the typeFullName again + // Ex - private Consumer consumer; + // From Consumer we need to get to Consumer so splitting it by '<' and then combining with '<' to + // form typeFullName as Consumer + + val typeFullNameWithoutGenericSplit = typeInfoCalc + .fullName(v.getType) + .orElse(scope.lookupType(v.getTypeAsString)) + .getOrElse(guessTypeFullName(v.getTypeAsString)) + val typeFullName = + // Check if the typeFullName is unresolved and if it has generic information to resolve the typeFullName + if + typeFullNameWithoutGenericSplit + .contains(Defines.UnresolvedNamespace) && v.getTypeAsString.contains( + Defines.LeftAngularBracket + ) + then + val splitByLeftAngular = v.getTypeAsString.split(Defines.LeftAngularBracket) + scope.lookupType(splitByLeftAngular.head) match + case Some(fullName) => + fullName + splitByLeftAngular + .slice(1, splitByLeftAngular.size) + .mkString(Defines.LeftAngularBracket, Defines.LeftAngularBracket, "") + case None => typeFullNameWithoutGenericSplit + else typeFullNameWithoutGenericSplit + val name = v.getName.toString + val node = memberNode(v, name, s"$typeFullName $name", typeFullName) + val memberAst = Ast(node) + val annotationAsts = annotations.asScala.map(astForAnnotationExpr) + + val fieldDeclModifiers = modifiersForFieldDeclaration(fieldDeclaration) + + scope.addMember(node, fieldDeclaration.isStatic) + + memberAst + .withChildren(annotationAsts) + .withChildren(fieldDeclModifiers) + end astForFieldVariable + + private def astForConstructor(constructorDeclaration: ConstructorDeclaration): Ast = + val constructorNode = createPartialMethod(constructorDeclaration) + .name(io.appthreat.x2cpg.Defines.ConstructorMethodName) + + scope.pushMethodScope(constructorNode, ExpectedType.Void) + val maybeResolved = tryWithSafeStackOverflow(constructorDeclaration.resolve()) + + val parameterAsts = astsForParameterList(constructorDeclaration.getParameters).toList + val paramTypes = argumentTypesForMethodLike(maybeResolved) + val signature = composeSignature(Some(TypeConstants.Void), paramTypes, parameterAsts.size) + val typeFullName = scope.enclosingTypeDeclFullName + val fullName = + composeMethodFullName( + typeFullName.getOrElse(Defines.UnresolvedNamespace), + Defines.ConstructorMethodName, + signature ) - end astForConstructor - - private def thisNodeForMethod( - maybeTypeFullName: Option[String], - lineNumber: Option[Integer] - ): NewMethodParameterIn = - val typeFullName = typeInfoCalc.registerType(maybeTypeFullName.getOrElse(TypeConstants.Any)) - NodeBuilders.newThisParameterNode( - typeFullName = typeFullName, - dynamicTypeHintFullName = maybeTypeFullName.toSeq, - line = lineNumber + val typeNameLookup = fullName.takeWhile(_ != ':').split("\\.").dropRight(1).mkString(".") + constructorNode + .fullName(fullName) + .signature(signature) + if packagesJarMappings.contains(typeNameLookup) then + constructorNode.astParentType(packagesJarMappings.getOrElse( + typeNameLookup, + mutable.Set.empty + ).head) + + parameterAsts.foreach { ast => + ast.root match + case Some(parameter: NewMethodParameterIn) => scope.addParameter(parameter) + case _ => // This should never happen + } + + val thisNode = thisNodeForMethod(typeFullName, line(constructorDeclaration)) + scope.addParameter(thisNode) + val thisAst = Ast(thisNode) + + val bodyAst = astForConstructorBody(Some(constructorDeclaration.getBody)) + val methodReturn = constructorReturnNode(constructorDeclaration) + + val annotationAsts = + constructorDeclaration.getAnnotations.asScala.map(astForAnnotationExpr).toList + + val modifiers = + NewModifier().modifierType(ModifierTypes.CONSTRUCTOR) :: modifiersForMethod( + constructorDeclaration + ).filterNot( + _.modifierType == ModifierTypes.VIRTUAL ) - private def convertAnnotationValueExpr(expr: Expression): Option[Ast] = - expr match - case arrayInit: ArrayInitializerExpr => - val arrayInitNode = NewArrayInitializer() - .code(arrayInit.toString) - val initElementAsts = arrayInit.getValues.asScala.toList.map { value => - convertAnnotationValueExpr(value) - } - - setArgumentIndices(initElementAsts.flatten) - - val returnAst = initElementAsts.foldLeft(Ast(arrayInitNode)) { - case (ast, Some(elementAst)) => - ast.withChild(elementAst) - case (ast, _) => ast - } - Some(returnAst) - - case annotationExpr: AnnotationExpr => - Some(astForAnnotationExpr(annotationExpr)) - - case literalExpr: LiteralExpr => - Some(astForAnnotationLiteralExpr(literalExpr)) - - case _: ClassExpr => - // TODO: Implement for known case - None - - case _: FieldAccessExpr => - // TODO: Implement for known case - None - - case _: BinaryExpr => - // TODO: Implement for known case - None + scope.popScope() + + methodAstWithAnnotations( + constructorNode, + thisAst :: parameterAsts, + bodyAst, + methodReturn, + modifiers, + annotationAsts + ) + end astForConstructor + + private def thisNodeForMethod( + maybeTypeFullName: Option[String], + lineNumber: Option[Integer] + ): NewMethodParameterIn = + val typeFullName = typeInfoCalc.registerType(maybeTypeFullName.getOrElse(TypeConstants.Any)) + NodeBuilders.newThisParameterNode( + typeFullName = typeFullName, + dynamicTypeHintFullName = maybeTypeFullName.toSeq, + line = lineNumber + ) + + private def convertAnnotationValueExpr(expr: Expression): Option[Ast] = + expr match + case arrayInit: ArrayInitializerExpr => + val arrayInitNode = NewArrayInitializer() + .code(arrayInit.toString) + val initElementAsts = arrayInit.getValues.asScala.toList.map { value => + convertAnnotationValueExpr(value) + } - case _: NameExpr => - // TODO: Implement for known case - None + setArgumentIndices(initElementAsts.flatten) - case _ => - logger.debug( - s"convertAnnotationValueExpr not yet implemented for unknown case ${expr.getClass}" - ) - None - - private def astForAnnotationLiteralExpr(literalExpr: LiteralExpr): Ast = - val valueNode = - literalExpr match - case literal: StringLiteralExpr => newAnnotationLiteralNode(literal.getValue) - case literal: IntegerLiteralExpr => newAnnotationLiteralNode(literal.getValue) - case literal: BooleanLiteralExpr => - newAnnotationLiteralNode(java.lang.Boolean.toString(literal.getValue)) - case literal: CharLiteralExpr => newAnnotationLiteralNode(literal.getValue) - case literal: DoubleLiteralExpr => newAnnotationLiteralNode(literal.getValue) - case literal: LongLiteralExpr => newAnnotationLiteralNode(literal.getValue) - case _: NullLiteralExpr => newAnnotationLiteralNode("null") - case literal: TextBlockLiteralExpr => newAnnotationLiteralNode(literal.getValue) - - Ast(valueNode) - - private def exprNameFromStack(expr: Expression): Option[String] = expr match - case annotation: AnnotationExpr => - scope.lookupType(annotation.getNameAsString) - case namedExpr: NodeWithName[?] => - scope.lookupVariableOrType(namedExpr.getNameAsString) - case namedExpr: NodeWithSimpleName[?] => - scope.lookupVariableOrType(namedExpr.getNameAsString) - // JavaParser doesn't handle literals well for some reason - case _: BooleanLiteralExpr => Some("boolean") - case _: CharLiteralExpr => Some("char") - case _: DoubleLiteralExpr => Some("double") - case _: IntegerLiteralExpr => Some("int") - case _: LongLiteralExpr => Some("long") - case _: NullLiteralExpr => Some("null") - case _: StringLiteralExpr => Some("java.lang.String") - case _: TextBlockLiteralExpr => Some("java.lang.String") - case _ => None - - private def expressionReturnTypeFullName(expr: Expression): Option[String] = - - val resolvedTypeOption = tryWithSafeStackOverflow(expr.calculateResolvedType()) match - case Failure(ex) => - ex match - // If ast parser fails to resolve type, try resolving locally by using name - // Precaution when resolving by name, we only want to resolve for case when the expr is solely a MethodCallExpr - // and doesn't have a scope to it - case symbolException: UnsolvedSymbolException => - expr match - case callExpr: MethodCallExpr => - callExpr.getScope.toScala match - case Some(_: Expression) => None - case _ => scope.lookupType(symbolException.getName) - case _ => None - case _ => None - case Success(resolvedType) => typeInfoCalc.fullName(resolvedType) - resolvedTypeOption.orElse(exprNameFromStack(expr)) - - private def guessTypeFullName(initString: String): String = - initString match - case x - if Seq( - "Override", - "Deprecated", - "SuppressWarnings", - "SafeVarargs", - "FunctionalInterface", - "Native" - ).contains(x) => s"java.lang.$x" - case y if y.startsWith("java.") => y - case _ => s"${Defines.UnresolvedNamespace}.${initString}" - - private def astForAnnotationExpr(annotationExpr: AnnotationExpr): Ast = - val fallbackType = guessTypeFullName(annotationExpr.getNameAsString) - val fullName = expressionReturnTypeFullName(annotationExpr).getOrElse(fallbackType) - val code = annotationExpr.toString - val name = annotationExpr.getName.getIdentifier - val node = annotationNode(annotationExpr, code, name, fullName) - annotationExpr match - case _: MarkerAnnotationExpr => - annotationAst(node, List.empty) - case normal: NormalAnnotationExpr => - val assignmentAsts = normal.getPairs.asScala.toList.map { pair => - annotationAssignmentAst( - pair.getName.getIdentifier, - pair.toString, - convertAnnotationValueExpr(pair.getValue).getOrElse(Ast()) - ) - } - annotationAst(node, assignmentAsts) - case single: SingleMemberAnnotationExpr => - val assignmentAsts = List( - annotationAssignmentAst( - "value", - single.getMemberValue.toString, - convertAnnotationValueExpr(single.getMemberValue).getOrElse(Ast()) - ) - ) - annotationAst(node, assignmentAsts) - end match - end astForAnnotationExpr - - private def abstractModifierForCallable( - callableDeclaration: CallableDeclaration[?], - isInterfaceMethod: Boolean - ): Option[NewModifier] = - callableDeclaration match - case methodDeclaration: MethodDeclaration => - Option.when( - methodDeclaration.isAbstract || (isInterfaceMethod && !methodDeclaration.isDefault) - ) { - newModifierNode(ModifierTypes.ABSTRACT) - } + val returnAst = initElementAsts.foldLeft(Ast(arrayInitNode)) { + case (ast, Some(elementAst)) => + ast.withChild(elementAst) + case (ast, _) => ast + } + Some(returnAst) - case _ => None + case annotationExpr: AnnotationExpr => + Some(astForAnnotationExpr(annotationExpr)) - private def modifiersForMethod(methodDeclaration: CallableDeclaration[?]): List[NewModifier] = - val isInterfaceMethod = scope.enclosingTypeDecl.exists(_.code.contains("interface ")) + case literalExpr: LiteralExpr => + Some(astForAnnotationLiteralExpr(literalExpr)) - val abstractModifier = abstractModifierForCallable(methodDeclaration, isInterfaceMethod) + case _: ClassExpr => + // TODO: Implement for known case + None - val staticVirtualModifierType = - if methodDeclaration.isStatic then ModifierTypes.STATIC else ModifierTypes.VIRTUAL - val staticVirtualModifier = Some(newModifierNode(staticVirtualModifierType)) + case _: FieldAccessExpr => + // TODO: Implement for known case + None - val accessModifierType = if methodDeclaration.isPublic then - Some(ModifierTypes.PUBLIC) - else if methodDeclaration.isPrivate then - Some(ModifierTypes.PRIVATE) - else if methodDeclaration.isProtected then - Some(ModifierTypes.PROTECTED) - else if isInterfaceMethod then - // TODO: more robust interface check - Some(ModifierTypes.PUBLIC) - else + case _: BinaryExpr => + // TODO: Implement for known case None - val accessModifier = accessModifierType.map(newModifierNode) - - List(accessModifier, abstractModifier, staticVirtualModifier).flatten - end modifiersForMethod - - private def getIdentifiersForTypeParameters(methodDeclaration: MethodDeclaration) - : List[NewIdentifier] = - methodDeclaration.getTypeParameters.asScala.map { typeParameter => - val name = typeParameter.getNameAsString - val typeFullName = typeParameter.getTypeBound.asScala.headOption - .flatMap(typeInfoCalc.fullName) - .getOrElse(TypeConstants.Object) - typeInfoCalc.registerType(typeFullName) - NewIdentifier().name(name).typeFullName(typeFullName) - }.toList + case _: NameExpr => + // TODO: Implement for known case + None - private def astForMethod(methodDeclaration: MethodDeclaration): Ast = - val methodNode = createPartialMethod(methodDeclaration) + case _ => + logger.debug( + s"convertAnnotationValueExpr not yet implemented for unknown case ${expr.getClass}" + ) + None - val typeParameters = getIdentifiersForTypeParameters(methodDeclaration) + private def astForAnnotationLiteralExpr(literalExpr: LiteralExpr): Ast = + val valueNode = + literalExpr match + case literal: StringLiteralExpr => newAnnotationLiteralNode(literal.getValue) + case literal: IntegerLiteralExpr => newAnnotationLiteralNode(literal.getValue) + case literal: BooleanLiteralExpr => + newAnnotationLiteralNode(java.lang.Boolean.toString(literal.getValue)) + case literal: CharLiteralExpr => newAnnotationLiteralNode(literal.getValue) + case literal: DoubleLiteralExpr => newAnnotationLiteralNode(literal.getValue) + case literal: LongLiteralExpr => newAnnotationLiteralNode(literal.getValue) + case _: NullLiteralExpr => newAnnotationLiteralNode("null") + case literal: TextBlockLiteralExpr => newAnnotationLiteralNode(literal.getValue) + + Ast(valueNode) + + private def exprNameFromStack(expr: Expression): Option[String] = expr match + case annotation: AnnotationExpr => + scope.lookupType(annotation.getNameAsString) + case namedExpr: NodeWithName[?] => + scope.lookupVariableOrType(namedExpr.getNameAsString) + case namedExpr: NodeWithSimpleName[?] => + scope.lookupVariableOrType(namedExpr.getNameAsString) + // JavaParser doesn't handle literals well for some reason + case _: BooleanLiteralExpr => Some("boolean") + case _: CharLiteralExpr => Some("char") + case _: DoubleLiteralExpr => Some("double") + case _: IntegerLiteralExpr => Some("int") + case _: LongLiteralExpr => Some("long") + case _: NullLiteralExpr => Some("null") + case _: StringLiteralExpr => Some("java.lang.String") + case _: TextBlockLiteralExpr => Some("java.lang.String") + case _ => None + + private def expressionReturnTypeFullName(expr: Expression): Option[String] = + + val resolvedTypeOption = tryWithSafeStackOverflow(expr.calculateResolvedType()) match + case Failure(ex) => + ex match + // If ast parser fails to resolve type, try resolving locally by using name + // Precaution when resolving by name, we only want to resolve for case when the expr is solely a MethodCallExpr + // and doesn't have a scope to it + case symbolException: UnsolvedSymbolException => + expr match + case callExpr: MethodCallExpr => + callExpr.getScope.toScala match + case Some(_: Expression) => None + case _ => scope.lookupType(symbolException.getName) + case _ => None + case _ => None + case Success(resolvedType) => typeInfoCalc.fullName(resolvedType) + resolvedTypeOption.orElse(exprNameFromStack(expr)) + + private def guessTypeFullName(initString: String): String = + initString match + case x + if Seq( + "Override", + "Deprecated", + "SuppressWarnings", + "SafeVarargs", + "FunctionalInterface", + "Native" + ).contains(x) => s"java.lang.$x" + case y if y.startsWith("java.") => y + case _ => s"${Defines.UnresolvedNamespace}.${initString}" + + private def astForAnnotationExpr(annotationExpr: AnnotationExpr): Ast = + val fallbackType = guessTypeFullName(annotationExpr.getNameAsString) + val fullName = expressionReturnTypeFullName(annotationExpr).getOrElse(fallbackType) + val code = annotationExpr.toString + val name = annotationExpr.getName.getIdentifier + val node = annotationNode(annotationExpr, code, name, fullName) + annotationExpr match + case _: MarkerAnnotationExpr => + annotationAst(node, List.empty) + case normal: NormalAnnotationExpr => + val assignmentAsts = normal.getPairs.asScala.toList.map { pair => + annotationAssignmentAst( + pair.getName.getIdentifier, + pair.toString, + convertAnnotationValueExpr(pair.getValue).getOrElse(Ast()) + ) + } + annotationAst(node, assignmentAsts) + case single: SingleMemberAnnotationExpr => + val assignmentAsts = List( + annotationAssignmentAst( + "value", + single.getMemberValue.toString, + convertAnnotationValueExpr(single.getMemberValue).getOrElse(Ast()) + ) + ) + annotationAst(node, assignmentAsts) + end match + end astForAnnotationExpr + + private def abstractModifierForCallable( + callableDeclaration: CallableDeclaration[?], + isInterfaceMethod: Boolean + ): Option[NewModifier] = + callableDeclaration match + case methodDeclaration: MethodDeclaration => + Option.when( + methodDeclaration.isAbstract || (isInterfaceMethod && !methodDeclaration.isDefault) + ) { + newModifierNode(ModifierTypes.ABSTRACT) + } - val maybeResolved = tryWithSafeStackOverflow(methodDeclaration.resolve()) - val expectedReturnType = Try(symbolSolver.toResolvedType( - methodDeclaration.getType, - classOf[ResolvedType] - )).toOption - val simpleMethodReturnType = methodDeclaration.getTypeAsString() - val returnTypeFullName = expectedReturnType + case _ => None + + private def modifiersForMethod(methodDeclaration: CallableDeclaration[?]): List[NewModifier] = + val isInterfaceMethod = scope.enclosingTypeDecl.exists(_.code.contains("interface ")) + + val abstractModifier = abstractModifierForCallable(methodDeclaration, isInterfaceMethod) + + val staticVirtualModifierType = + if methodDeclaration.isStatic then ModifierTypes.STATIC else ModifierTypes.VIRTUAL + val staticVirtualModifier = Some(newModifierNode(staticVirtualModifierType)) + + val accessModifierType = if methodDeclaration.isPublic then + Some(ModifierTypes.PUBLIC) + else if methodDeclaration.isPrivate then + Some(ModifierTypes.PRIVATE) + else if methodDeclaration.isProtected then + Some(ModifierTypes.PROTECTED) + else if isInterfaceMethod then + // TODO: more robust interface check + Some(ModifierTypes.PUBLIC) + else + None + val accessModifier = accessModifierType.map(newModifierNode) + + List(accessModifier, abstractModifier, staticVirtualModifier).flatten + end modifiersForMethod + + private def getIdentifiersForTypeParameters(methodDeclaration: MethodDeclaration) + : List[NewIdentifier] = + methodDeclaration.getTypeParameters.asScala.map { typeParameter => + val name = typeParameter.getNameAsString + val typeFullName = typeParameter.getTypeBound.asScala.headOption .flatMap(typeInfoCalc.fullName) - .orElse(scope.lookupType(simpleMethodReturnType)) - .orElse(typeParameters.find(_.name == simpleMethodReturnType).map(_.typeFullName)) - - scope.pushMethodScope(methodNode, ExpectedType(returnTypeFullName, expectedReturnType)) - typeParameters.foreach { typeParameter => - scope.addType(typeParameter.name, typeParameter.typeFullName) - } - - val parameterAsts = astsForParameterList(methodDeclaration.getParameters) - val parameterTypes = argumentTypesForMethodLike(maybeResolved) - val signature = composeSignature(returnTypeFullName, parameterTypes, parameterAsts.size) - val namespaceName = scope.enclosingTypeDeclFullName.getOrElse(Defines.UnresolvedNamespace) - val methodFullName = - composeMethodFullName(namespaceName, methodDeclaration.getNameAsString, signature) - - methodNode - .fullName(methodFullName) - .signature(signature) - val typeNameLookup = - methodFullName.takeWhile(_ != ':').split("\\.").dropRight(1).mkString(".") - if packagesJarMappings != null && packagesJarMappings.contains(typeNameLookup) then - methodNode.astParentType(packagesJarMappings.getOrElse( - typeNameLookup, - mutable.Set.empty - ).head) - val thisNode = Option.when(!methodDeclaration.isStatic) { - val typeFullName = scope.enclosingTypeDeclFullName - thisNodeForMethod(typeFullName, line(methodDeclaration)) - } - val thisAst = thisNode.map(Ast(_)).toList + .getOrElse(TypeConstants.Object) + typeInfoCalc.registerType(typeFullName) - thisNode.foreach { node => - scope.addParameter(node) + NewIdentifier().name(name).typeFullName(typeFullName) + }.toList + + private def astForMethod(methodDeclaration: MethodDeclaration): Ast = + val methodNode = createPartialMethod(methodDeclaration) + + val typeParameters = getIdentifiersForTypeParameters(methodDeclaration) + + val maybeResolved = tryWithSafeStackOverflow(methodDeclaration.resolve()) + val expectedReturnType = Try(symbolSolver.toResolvedType( + methodDeclaration.getType, + classOf[ResolvedType] + )).toOption + val simpleMethodReturnType = methodDeclaration.getTypeAsString() + val returnTypeFullName = expectedReturnType + .flatMap(typeInfoCalc.fullName) + .orElse(scope.lookupType(simpleMethodReturnType)) + .orElse(typeParameters.find(_.name == simpleMethodReturnType).map(_.typeFullName)) + + scope.pushMethodScope(methodNode, ExpectedType(returnTypeFullName, expectedReturnType)) + typeParameters.foreach { typeParameter => + scope.addType(typeParameter.name, typeParameter.typeFullName) + } + + val parameterAsts = astsForParameterList(methodDeclaration.getParameters) + val parameterTypes = argumentTypesForMethodLike(maybeResolved) + val signature = composeSignature(returnTypeFullName, parameterTypes, parameterAsts.size) + val namespaceName = scope.enclosingTypeDeclFullName.getOrElse(Defines.UnresolvedNamespace) + val methodFullName = + composeMethodFullName(namespaceName, methodDeclaration.getNameAsString, signature) + + methodNode + .fullName(methodFullName) + .signature(signature) + val typeNameLookup = + methodFullName.takeWhile(_ != ':').split("\\.").dropRight(1).mkString(".") + if packagesJarMappings != null && packagesJarMappings.contains(typeNameLookup) then + methodNode.astParentType(packagesJarMappings.getOrElse( + typeNameLookup, + mutable.Set.empty + ).head) + val thisNode = Option.when(!methodDeclaration.isStatic) { + val typeFullName = scope.enclosingTypeDeclFullName + thisNodeForMethod(typeFullName, line(methodDeclaration)) + } + val thisAst = thisNode.map(Ast(_)).toList + + thisNode.foreach { node => + scope.addParameter(node) + } + + val bodyAst = methodDeclaration.getBody.toScala.map(astForBlockStatement(_)).getOrElse(Ast( + NewBlock() + )) + val methodReturn = newMethodReturnNode( + returnTypeFullName.getOrElse(TypeConstants.Any), + None, + line(methodDeclaration.getType), + column(methodDeclaration.getType) + ) + + val annotationAsts = + methodDeclaration.getAnnotations.asScala.map(astForAnnotationExpr).toSeq + + val modifiers = modifiersForMethod(methodDeclaration) + + scope.popScope() + + methodAstWithAnnotations( + methodNode, + thisAst ++ parameterAsts, + bodyAst, + methodReturn, + modifiers, + annotationAsts + ) + end astForMethod + + private def constructorReturnNode(constructorDeclaration: ConstructorDeclaration) + : NewMethodReturn = + val line = constructorDeclaration.getEnd.map(x => Integer.valueOf(x.line)).toScala + val column = constructorDeclaration.getEnd.map(x => Integer.valueOf(x.column)).toScala + newMethodReturnNode(TypeConstants.Void, None, line, column) + + /** Constructor and Method declarations share a lot of fields, so this method adds the fields they + * have in common. `fullName` and `signature` are omitted + */ + private def createPartialMethod(declaration: CallableDeclaration[?]): NewMethod = + val code = declaration.getDeclarationAsString.trim + val columnNumber = declaration.getBegin.map(x => Integer.valueOf(x.column)).toScala + val endLine = declaration.getEnd.map(x => Integer.valueOf(x.line)).toScala + val endColumn = declaration.getEnd.map(x => Integer.valueOf(x.column)).toScala + + val methodNode = NewMethod() + .name(declaration.getNameAsString) + .code(code) + .isExternal(false) + .filename(filename) + .lineNumber(line(declaration)) + .columnNumber(columnNumber) + .lineNumberEnd(endLine) + .columnNumberEnd(endColumn) + + methodNode + + private def astForConstructorBody(body: Option[BlockStmt]): Ast = + val containsThisInvocation = + body + .flatMap(_.getStatements.asScala.headOption) + .collect { case e: ExplicitConstructorInvocationStmt => e } + .exists(_.isThis) + + val memberInitializers = + if containsThisInvocation then + Seq.empty + else + scope.memberInitializers + + body match + case Some(b) => astForBlockStatement(b, prefixAsts = memberInitializers) + + case None => Ast(NewBlock()).withChildren(memberInitializers) + + private def astsForLabeledStatement(stmt: LabeledStmt): Seq[Ast] = + val jumpTargetAst = Ast(NewJumpTarget().name(stmt.getLabel.toString)) + val stmtAst = astsForStatement(stmt.getStatement).toList + + jumpTargetAst :: stmtAst + + private def astForThrow(stmt: ThrowStmt): Ast = + val throwNode = NewCall() + .name(".throw") + .methodFullName(".throw") + .lineNumber(line(stmt)) + .columnNumber(column(stmt)) + .code(stmt.toString()) + .dispatchType(DispatchTypes.STATIC_DISPATCH) + + val args = astsForExpression(stmt.getExpression, ExpectedType.empty) + + callAst(throwNode, args) + + private def astForCatchClause(catchClause: CatchClause): Ast = + astForBlockStatement(catchClause.getBody) + + private def astsForTry(stmt: TryStmt): Seq[Ast] = + val tryNode = NewControlStructure() + .controlStructureType(ControlStructureTypes.TRY) + .code("try") + .lineNumber(line(stmt)) + .columnNumber(column(stmt)) + + val resources = stmt.getResources.asScala.flatMap(astsForExpression( + _, + expectedType = ExpectedType.empty + )).toList + val tryAst = astForBlockStatement(stmt.getTryBlock, codeStr = "try") + val catchAsts = stmt.getCatchClauses.asScala.map(astForCatchClause) + val catchBlock = Option + .when(catchAsts.nonEmpty) { + Ast(NewBlock().code("catch")).withChildren(catchAsts) } + .toList + val finallyAst = + stmt.getFinallyBlock.toScala.map(astForBlockStatement(_, "finally")).toList + + val controlStructureAst = Ast(tryNode) + .withChild(tryAst) + .withChildren(catchBlock) + .withChildren(finallyAst) + + resources.appended(controlStructureAst) + end astsForTry + + private def astsForStatement(statement: Statement): Seq[Ast] = + // TODO: Implement missing handlers + // case _: LocalClassDeclarationStmt => Seq() + // case _: LocalRecordDeclarationStmt => Seq() + // case _: YieldStmt => Seq() + statement match + case x: ExplicitConstructorInvocationStmt => + Seq(astForExplicitConstructorInvocation(x)) + case x: AssertStmt => Seq(astForAssertStatement(x)) + case x: BlockStmt => Seq(astForBlockStatement(x)) + case x: BreakStmt => Seq(astForBreakStatement(x)) + case x: ContinueStmt => Seq(astForContinueStatement(x)) + case x: DoStmt => Seq(astForDo(x)) + case _: EmptyStmt => Seq() // Intentionally skipping this + case x: ExpressionStmt => astsForExpression(x.getExpression, ExpectedType.Void) + case x: ForEachStmt => astForForEach(x) + case x: ForStmt => Seq(astForFor(x)) + case x: IfStmt => Seq(astForIf(x)) + case x: LabeledStmt => astsForLabeledStatement(x) + case x: ReturnStmt => Seq(astForReturnNode(x)) + case x: SwitchStmt => Seq(astForSwitchStatement(x)) + case x: SynchronizedStmt => Seq(astForSynchronizedStatement(x)) + case x: ThrowStmt => Seq(astForThrow(x)) + case x: TryStmt => astsForTry(x) + case x: WhileStmt => Seq(astForWhile(x)) + // case x: LocalClassDeclarationStmt => Seq(astForLocalClassDeclarationStmt(x)) + case x => + logger.debug( + s"Attempting to generate AST for unknown statement of type ${x.getClass}" + ) + Seq(unknownAst(x)) - val bodyAst = methodDeclaration.getBody.toScala.map(astForBlockStatement(_)).getOrElse(Ast( - NewBlock() - )) - val methodReturn = newMethodReturnNode( - returnTypeFullName.getOrElse(TypeConstants.Any), - None, - line(methodDeclaration.getType), - column(methodDeclaration.getType) - ) - - val annotationAsts = - methodDeclaration.getAnnotations.asScala.map(astForAnnotationExpr).toSeq - - val modifiers = modifiersForMethod(methodDeclaration) - - scope.popScope() - - methodAstWithAnnotations( - methodNode, - thisAst ++ parameterAsts, - bodyAst, - methodReturn, - modifiers, - annotationAsts - ) - end astForMethod - - private def constructorReturnNode(constructorDeclaration: ConstructorDeclaration) - : NewMethodReturn = - val line = constructorDeclaration.getEnd.map(x => Integer.valueOf(x.line)).toScala - val column = constructorDeclaration.getEnd.map(x => Integer.valueOf(x.column)).toScala - newMethodReturnNode(TypeConstants.Void, None, line, column) - - /** Constructor and Method declarations share a lot of fields, so this method adds the fields - * they have in common. `fullName` and `signature` are omitted - */ - private def createPartialMethod(declaration: CallableDeclaration[?]): NewMethod = - val code = declaration.getDeclarationAsString.trim - val columnNumber = declaration.getBegin.map(x => Integer.valueOf(x.column)).toScala - val endLine = declaration.getEnd.map(x => Integer.valueOf(x.line)).toScala - val endColumn = declaration.getEnd.map(x => Integer.valueOf(x.column)).toScala - - val methodNode = NewMethod() - .name(declaration.getNameAsString) - .code(code) - .isExternal(false) - .filename(filename) - .lineNumber(line(declaration)) - .columnNumber(columnNumber) - .lineNumberEnd(endLine) - .columnNumberEnd(endColumn) - - methodNode - - private def astForConstructorBody(body: Option[BlockStmt]): Ast = - val containsThisInvocation = - body - .flatMap(_.getStatements.asScala.headOption) - .collect { case e: ExplicitConstructorInvocationStmt => e } - .exists(_.isThis) - - val memberInitializers = - if containsThisInvocation then - Seq.empty - else - scope.memberInitializers - - body match - case Some(b) => astForBlockStatement(b, prefixAsts = memberInitializers) - - case None => Ast(NewBlock()).withChildren(memberInitializers) + private def astForElse(maybeStmt: Option[Statement]): Option[Ast] = + maybeStmt.map { stmt => + val elseAsts = astsForStatement(stmt) - private def astsForLabeledStatement(stmt: LabeledStmt): Seq[Ast] = - val jumpTargetAst = Ast(NewJumpTarget().name(stmt.getLabel.toString)) - val stmtAst = astsForStatement(stmt.getStatement).toList + val elseNode = + NewControlStructure() + .controlStructureType(ControlStructureTypes.ELSE) + .lineNumber(line(stmt)) + .columnNumber(column(stmt)) + .code("else") - jumpTargetAst :: stmtAst + Ast(elseNode).withChildren(elseAsts) + } - private def astForThrow(stmt: ThrowStmt): Ast = - val throwNode = NewCall() - .name(".throw") - .methodFullName(".throw") + def astForIf(stmt: IfStmt): Ast = + val ifNode = + NewControlStructure() + .controlStructureType(ControlStructureTypes.IF) .lineNumber(line(stmt)) .columnNumber(column(stmt)) - .code(stmt.toString()) - .dispatchType(DispatchTypes.STATIC_DISPATCH) - - val args = astsForExpression(stmt.getExpression, ExpectedType.empty) - - callAst(throwNode, args) - - private def astForCatchClause(catchClause: CatchClause): Ast = - astForBlockStatement(catchClause.getBody) - - private def astsForTry(stmt: TryStmt): Seq[Ast] = - val tryNode = NewControlStructure() - .controlStructureType(ControlStructureTypes.TRY) - .code("try") + .code(s"if (${stmt.getCondition.toString})") + + val conditionAst = + astsForExpression(stmt.getCondition, ExpectedType.Boolean).headOption.toList + + val thenAsts = astsForStatement(stmt.getThenStmt) + val elseAst = astForElse(stmt.getElseStmt.toScala).toList + + val ast = Ast(ifNode) + .withChildren(conditionAst) + .withChildren(thenAsts) + .withChildren(elseAst) + + conditionAst.flatMap(_.root.toList) match + case r :: Nil => + ast.withConditionEdge(ifNode, r) + case _ => + ast + end astForIf + + def astForWhile(stmt: WhileStmt): Ast = + val conditionAst = astsForExpression(stmt.getCondition, ExpectedType.Boolean).headOption + val stmtAsts = astsForStatement(stmt.getBody) + val code = s"while (${stmt.getCondition.toString})" + val lineNumber = line(stmt) + val columnNumber = column(stmt) + + whileAst(conditionAst, stmtAsts, Some(code), lineNumber, columnNumber) + + private def astForDo(stmt: DoStmt): Ast = + val conditionAst = astsForExpression(stmt.getCondition, ExpectedType.Boolean).headOption + val stmtAsts = astsForStatement(stmt.getBody) + val code = s"do {...} while (${stmt.getCondition.toString})" + val lineNumber = line(stmt) + val columnNumber = column(stmt) + + doWhileAst(conditionAst, stmtAsts, Some(code), lineNumber, columnNumber) + + private def astForBreakStatement(stmt: BreakStmt): Ast = + val node = NewControlStructure() + .controlStructureType(ControlStructureTypes.BREAK) + .lineNumber(line(stmt)) + .columnNumber(column(stmt)) + .code(stmt.toString) + Ast(node) + + private def astForContinueStatement(stmt: ContinueStmt): Ast = + val node = NewControlStructure() + .controlStructureType(ControlStructureTypes.CONTINUE) + .lineNumber(line(stmt)) + .columnNumber(column(stmt)) + .code(stmt.toString) + Ast(node) + + private def getForCode(stmt: ForStmt): String = + val init = stmt.getInitialization.asScala.map(_.toString).mkString(", ") + val compare = stmt.getCompare.toScala.map(_.toString) + val update = stmt.getUpdate.asScala.map(_.toString).mkString(", ") + s"for ($init; $compare; $update)" + def astForFor(stmt: ForStmt): Ast = + val forNode = + NewControlStructure() + .controlStructureType(ControlStructureTypes.FOR) + .code(getForCode(stmt)) .lineNumber(line(stmt)) .columnNumber(column(stmt)) - val resources = stmt.getResources.asScala.flatMap(astsForExpression( + val initAsts = + stmt.getInitialization.asScala.flatMap(astsForExpression( _, expectedType = ExpectedType.empty - )).toList - val tryAst = astForBlockStatement(stmt.getTryBlock, codeStr = "try") - val catchAsts = stmt.getCatchClauses.asScala.map(astForCatchClause) - val catchBlock = Option - .when(catchAsts.nonEmpty) { - Ast(NewBlock().code("catch")).withChildren(catchAsts) - } - .toList - val finallyAst = - stmt.getFinallyBlock.toScala.map(astForBlockStatement(_, "finally")).toList - - val controlStructureAst = Ast(tryNode) - .withChild(tryAst) - .withChildren(catchBlock) - .withChildren(finallyAst) - - resources.appended(controlStructureAst) - end astsForTry - - private def astsForStatement(statement: Statement): Seq[Ast] = - // TODO: Implement missing handlers - // case _: LocalClassDeclarationStmt => Seq() - // case _: LocalRecordDeclarationStmt => Seq() - // case _: YieldStmt => Seq() - statement match - case x: ExplicitConstructorInvocationStmt => - Seq(astForExplicitConstructorInvocation(x)) - case x: AssertStmt => Seq(astForAssertStatement(x)) - case x: BlockStmt => Seq(astForBlockStatement(x)) - case x: BreakStmt => Seq(astForBreakStatement(x)) - case x: ContinueStmt => Seq(astForContinueStatement(x)) - case x: DoStmt => Seq(astForDo(x)) - case _: EmptyStmt => Seq() // Intentionally skipping this - case x: ExpressionStmt => astsForExpression(x.getExpression, ExpectedType.Void) - case x: ForEachStmt => astForForEach(x) - case x: ForStmt => Seq(astForFor(x)) - case x: IfStmt => Seq(astForIf(x)) - case x: LabeledStmt => astsForLabeledStatement(x) - case x: ReturnStmt => Seq(astForReturnNode(x)) - case x: SwitchStmt => Seq(astForSwitchStatement(x)) - case x: SynchronizedStmt => Seq(astForSynchronizedStatement(x)) - case x: ThrowStmt => Seq(astForThrow(x)) - case x: TryStmt => astsForTry(x) - case x: WhileStmt => Seq(astForWhile(x)) - // case x: LocalClassDeclarationStmt => Seq(astForLocalClassDeclarationStmt(x)) - case x => - logger.debug( - s"Attempting to generate AST for unknown statement of type ${x.getClass}" - ) - Seq(unknownAst(x)) - - private def astForElse(maybeStmt: Option[Statement]): Option[Ast] = - maybeStmt.map { stmt => - val elseAsts = astsForStatement(stmt) - - val elseNode = - NewControlStructure() - .controlStructureType(ControlStructureTypes.ELSE) - .lineNumber(line(stmt)) - .columnNumber(column(stmt)) - .code("else") - - Ast(elseNode).withChildren(elseAsts) - } - - def astForIf(stmt: IfStmt): Ast = - val ifNode = - NewControlStructure() - .controlStructureType(ControlStructureTypes.IF) - .lineNumber(line(stmt)) - .columnNumber(column(stmt)) - .code(s"if (${stmt.getCondition.toString})") - - val conditionAst = - astsForExpression(stmt.getCondition, ExpectedType.Boolean).headOption.toList - - val thenAsts = astsForStatement(stmt.getThenStmt) - val elseAst = astForElse(stmt.getElseStmt.toScala).toList - - val ast = Ast(ifNode) - .withChildren(conditionAst) - .withChildren(thenAsts) - .withChildren(elseAst) - - conditionAst.flatMap(_.root.toList) match - case r :: Nil => - ast.withConditionEdge(ifNode, r) - case _ => - ast - end astForIf - - def astForWhile(stmt: WhileStmt): Ast = - val conditionAst = astsForExpression(stmt.getCondition, ExpectedType.Boolean).headOption - val stmtAsts = astsForStatement(stmt.getBody) - val code = s"while (${stmt.getCondition.toString})" - val lineNumber = line(stmt) - val columnNumber = column(stmt) - - whileAst(conditionAst, stmtAsts, Some(code), lineNumber, columnNumber) - - private def astForDo(stmt: DoStmt): Ast = - val conditionAst = astsForExpression(stmt.getCondition, ExpectedType.Boolean).headOption - val stmtAsts = astsForStatement(stmt.getBody) - val code = s"do {...} while (${stmt.getCondition.toString})" - val lineNumber = line(stmt) - val columnNumber = column(stmt) - - doWhileAst(conditionAst, stmtAsts, Some(code), lineNumber, columnNumber) - - private def astForBreakStatement(stmt: BreakStmt): Ast = - val node = NewControlStructure() - .controlStructureType(ControlStructureTypes.BREAK) - .lineNumber(line(stmt)) - .columnNumber(column(stmt)) - .code(stmt.toString) - Ast(node) - - private def astForContinueStatement(stmt: ContinueStmt): Ast = - val node = NewControlStructure() - .controlStructureType(ControlStructureTypes.CONTINUE) - .lineNumber(line(stmt)) - .columnNumber(column(stmt)) - .code(stmt.toString) - Ast(node) - - private def getForCode(stmt: ForStmt): String = - val init = stmt.getInitialization.asScala.map(_.toString).mkString(", ") - val compare = stmt.getCompare.toScala.map(_.toString) - val update = stmt.getUpdate.asScala.map(_.toString).mkString(", ") - s"for ($init; $compare; $update)" - def astForFor(stmt: ForStmt): Ast = - val forNode = - NewControlStructure() - .controlStructureType(ControlStructureTypes.FOR) - .code(getForCode(stmt)) - .lineNumber(line(stmt)) - .columnNumber(column(stmt)) - - val initAsts = - stmt.getInitialization.asScala.flatMap(astsForExpression( - _, - expectedType = ExpectedType.empty - )) + )) - val compareAsts = stmt.getCompare.toScala.toList.flatMap { - astsForExpression(_, ExpectedType.Boolean) - } + val compareAsts = stmt.getCompare.toScala.toList.flatMap { + astsForExpression(_, ExpectedType.Boolean) + } + + val updateAsts = stmt.getUpdate.asScala.toList.flatMap { + astsForExpression(_, ExpectedType.empty) + } + + val stmtAsts = + astsForStatement(stmt.getBody) + + val ast = Ast(forNode) + .withChildren(initAsts) + .withChildren(compareAsts) + .withChildren(updateAsts) + .withChildren(stmtAsts) + + compareAsts.flatMap(_.root) match + case c :: Nil => + ast.withConditionEdge(forNode, c) + case _ => ast + end astForFor + + private def iterableAssignAstsForNativeForEach( + iterableExpression: Expression, + iterableType: Option[String] + ): (NodeTypeInfo, Seq[Ast]) = + val lineNo = line(iterableExpression) + val expectedType = ExpectedType(iterableType) + + val iterableAst = astsForExpression(iterableExpression, expectedType = expectedType) match + case Nil => + logger.debug( + s"Could not create AST for iterable expr $iterableExpression: $filename:l$lineNo" + ) + Ast() + case iterableAstHead :: Nil => iterableAstHead + case iterableAsts => + logger.debug( + s"Found multiple ASTS for iterable expr $iterableExpression: $filename:l$lineNo\nDropping all but the first!" + ) + iterableAsts.head - val updateAsts = stmt.getUpdate.asScala.toList.flatMap { - astsForExpression(_, ExpectedType.empty) - } + val iterableName = nextIterableName() + val iterableLocalNode = + localNode(iterableExpression, iterableName, iterableName, iterableType.getOrElse("ANY")) + val iterableLocalAst = Ast(iterableLocalNode) - val stmtAsts = - astsForStatement(stmt.getBody) - - val ast = Ast(forNode) - .withChildren(initAsts) - .withChildren(compareAsts) - .withChildren(updateAsts) - .withChildren(stmtAsts) - - compareAsts.flatMap(_.root) match - case c :: Nil => - ast.withConditionEdge(forNode, c) - case _ => ast - end astForFor - - private def iterableAssignAstsForNativeForEach( - iterableExpression: Expression, - iterableType: Option[String] - ): (NodeTypeInfo, Seq[Ast]) = - val lineNo = line(iterableExpression) - val expectedType = ExpectedType(iterableType) - - val iterableAst = astsForExpression(iterableExpression, expectedType = expectedType) match - case Nil => - logger.debug( - s"Could not create AST for iterable expr $iterableExpression: $filename:l$lineNo" - ) - Ast() - case iterableAstHead :: Nil => iterableAstHead - case iterableAsts => - logger.debug( - s"Found multiple ASTS for iterable expr $iterableExpression: $filename:l$lineNo\nDropping all but the first!" - ) - iterableAsts.head - - val iterableName = nextIterableName() - val iterableLocalNode = - localNode(iterableExpression, iterableName, iterableName, iterableType.getOrElse("ANY")) - val iterableLocalAst = Ast(iterableLocalNode) - - val iterableAssignNode = - newOperatorCallNode( - Operators.assignment, - code = "", - line = lineNo, - typeFullName = iterableType - ) - val iterableAssignIdentifier = - identifierNode( - iterableExpression, - iterableName, - iterableName, - iterableType.getOrElse("ANY") - ) - val iterableAssignArgs = List(Ast(iterableAssignIdentifier), iterableAst) - val iterableAssignAst = - callAst(iterableAssignNode, iterableAssignArgs) - .withRefEdge(iterableAssignIdentifier, iterableLocalNode) - - ( - NodeTypeInfo( - iterableLocalNode, - iterableLocalNode.name, - Some(iterableLocalNode.typeFullName) - ), - List(iterableLocalAst, iterableAssignAst) - ) - end iterableAssignAstsForNativeForEach - - private def nativeForEachIdxLocalNode(lineNo: Option[Integer]): NewLocal = - val idxName = nextIndexName() - val typeFullName = TypeConstants.Int - val idxLocal = - NewLocal() - .name(idxName) - .typeFullName(typeFullName) - .code(idxName) - .lineNumber(lineNo) - scope.addLocal(idxLocal) - idxLocal - - private def nativeForEachIdxInitializerAst(lineNo: Option[Integer], idxLocal: NewLocal): Ast = - val idxName = idxLocal.name - val idxInitializerCallNode = newOperatorCallNode( + val iterableAssignNode = + newOperatorCallNode( Operators.assignment, - code = s"int $idxName = 0", + code = "", line = lineNo, - typeFullName = Some(TypeConstants.Int) - ) - val idxIdentifierArg = newIdentifierNode(idxName, idxLocal.typeFullName) - val zeroLiteral = - NewLiteral() - .code("0") - .typeFullName(TypeConstants.Int) - .lineNumber(lineNo) - val idxInitializerArgAsts = List(Ast(idxIdentifierArg), Ast(zeroLiteral)) - callAst(idxInitializerCallNode, idxInitializerArgAsts) - .withRefEdge(idxIdentifierArg, idxLocal) - - private def nativeForEachCompareAst( - lineNo: Option[Integer], - iterableSource: NodeTypeInfo, - idxLocal: NewLocal - ): Ast = - val idxName = idxLocal.name - - val compareNode = newOperatorCallNode( - Operators.lessThan, - code = s"$idxName < ${iterableSource.name}.length", - typeFullName = Some(TypeConstants.Boolean), - line = lineNo - ) - val comparisonIdxIdentifier = newIdentifierNode(idxName, idxLocal.typeFullName) - val comparisonFieldAccess = newOperatorCallNode( - Operators.fieldAccess, - code = s"${iterableSource.name}.length", - typeFullName = Some(TypeConstants.Int), - line = lineNo + typeFullName = iterableType ) - val fieldAccessIdentifier = - newIdentifierNode(iterableSource.name, iterableSource.typeFullName.getOrElse("ANY")) - val fieldAccessFieldIdentifier = newFieldIdentifierNode("length", lineNo) - val fieldAccessArgs = List(fieldAccessIdentifier, fieldAccessFieldIdentifier).map(Ast(_)) - val fieldAccessAst = callAst(comparisonFieldAccess, fieldAccessArgs) - val compareArgs = List(Ast(comparisonIdxIdentifier), fieldAccessAst) - - // TODO: This is a workaround for a crash when looping over statically imported members. Handle those properly. - val iterableSourceNode = localParamOrMemberFromNode(iterableSource) - - callAst(compareNode, compareArgs) - .withRefEdge(comparisonIdxIdentifier, idxLocal) - .withRefEdges(fieldAccessIdentifier, iterableSourceNode.toList) - end nativeForEachCompareAst - - private def nativeForEachIncrementAst(lineNo: Option[Integer], idxLocal: NewLocal): Ast = - val incrementNode = newOperatorCallNode( - Operators.postIncrement, - code = s"${idxLocal.name}++", - typeFullName = Some(TypeConstants.Int), - line = lineNo + val iterableAssignIdentifier = + identifierNode( + iterableExpression, + iterableName, + iterableName, + iterableType.getOrElse("ANY") ) - val incrementArg = newIdentifierNode(idxLocal.name, idxLocal.typeFullName) - val incrementArgAst = Ast(incrementArg) - callAst(incrementNode, List(incrementArgAst)) - .withRefEdge(incrementArg, idxLocal) - - private def variableLocalForForEachBody(stmt: ForEachStmt): NewLocal = - val lineNo = line(stmt) - // Create item local - val maybeVariable = stmt.getVariable.getVariables.asScala.toList match - case Nil => - logger.debug(s"ForEach statement has empty variable list: $filename$lineNo") - None - case variable :: Nil => Some(variable) - case variable :: _ => - logger.debug( - s"ForEach statement defines multiple variables. Dropping all but the first: $filename$lineNo" - ) - Some(variable) - - val partialLocalNode = NewLocal().lineNumber(lineNo) - - maybeVariable match - case Some(variable) => - val name = variable.getNameAsString - val typeFullName = typeInfoCalc.fullName(variable.getType).getOrElse("ANY") - val localNode = partialLocalNode - .name(name) - .code(variable.getNameAsString) - .typeFullName(typeFullName) - - scope.addLocal(localNode) - localNode - - case None => - // Returning partialLocalNode here is fine since getting to this case means everything is broken anyways :) - partialLocalNode - end variableLocalForForEachBody - - private def localParamOrMemberFromNode(nodeTypeInfo: NodeTypeInfo): Option[NewNode] = - nodeTypeInfo.node match - case localNode: NewLocal => Some(localNode) - case memberNode: NewMember => Some(memberNode) - case parameterNode: NewMethodParameterIn => Some(parameterNode) - case _ => None - private def variableAssignForNativeForEachBody( - variableLocal: NewLocal, - idxLocal: NewLocal, - iterable: NodeTypeInfo - ): Ast = - // Everything will be on the same line as the `for` statement, but this is the most useful - // solution for debugging. - val lineNo = variableLocal.lineNumber - val varAssignNode = - newOperatorCallNode( - Operators.assignment, - PropertyDefaults.Code, - Some(variableLocal.typeFullName), - lineNo - ) - - val targetNode = newIdentifierNode(variableLocal.name, variableLocal.typeFullName) - - val indexAccessTypeFullName = iterable.typeFullName.map(_.replaceAll(raw"\[]", "")) - val indexAccess = - newOperatorCallNode( - Operators.indexAccess, - PropertyDefaults.Code, - indexAccessTypeFullName, - lineNo - ) - - val indexAccessIdentifier = - newIdentifierNode(iterable.name, iterable.typeFullName.getOrElse("ANY")) - val indexAccessIndex = newIdentifierNode(idxLocal.name, idxLocal.typeFullName) - - val indexAccessArgsAsts = List(indexAccessIdentifier, indexAccessIndex).map(Ast(_)) - val indexAccessAst = callAst(indexAccess, indexAccessArgsAsts) - - val iterableSourceNode = localParamOrMemberFromNode(iterable) - - val assignArgsAsts = List(Ast(targetNode), indexAccessAst) - callAst(varAssignNode, assignArgsAsts) - .withRefEdge(targetNode, variableLocal) - .withRefEdges(indexAccessIdentifier, iterableSourceNode.toList) - .withRefEdge(indexAccessIndex, idxLocal) - end variableAssignForNativeForEachBody - - private def nativeForEachBodyAst( - stmt: ForEachStmt, - idxLocal: NewLocal, - iterable: NodeTypeInfo - ): Ast = - val variableLocal = variableLocalForForEachBody(stmt) - val variableLocalAst = Ast(variableLocal) - val variableAssignAst = - variableAssignForNativeForEachBody(variableLocal, idxLocal, iterable) - - stmt.getBody match - case block: BlockStmt => - astForBlockStatement(block, prefixAsts = List(variableLocalAst, variableAssignAst)) - case statement => - val stmtAsts = astsForStatement(statement) - val blockNode = NewBlock().lineNumber(variableLocal.lineNumber) - Ast(blockNode) - .withChild(variableLocalAst) - .withChild(variableAssignAst) - .withChildren(stmtAsts) - end nativeForEachBodyAst - - private def astsForNativeForEach(stmt: ForEachStmt, iterableType: Option[String]): Seq[Ast] = - - // This is ugly, but for a case like `for (int x : new int[] { ... })` this creates a new LOCAL - // with the assignment `int[] $iterLocal0 = new int[] { ... }` before the FOR loop. - // TODO: Fix this - val (iterableSource: NodeTypeInfo, tempIterableInitAsts) = stmt.getIterable match - case nameExpr: NameExpr => - scope.lookupVariable(nameExpr.getNameAsString).asNodeInfoOption match - // If this is not the case, then the code is broken (iterable not in scope). - case Some(nodeTypeInfo) => (nodeTypeInfo, Nil) - case _ => iterableAssignAstsForNativeForEach(nameExpr, iterableType) - case iterableExpr => iterableAssignAstsForNativeForEach(iterableExpr, iterableType) - - val forNode = NewControlStructure() - .controlStructureType(ControlStructureTypes.FOR) - - val lineNo = line(stmt) - - val idxLocal = nativeForEachIdxLocalNode(lineNo) - val idxInitializerAst = nativeForEachIdxInitializerAst(lineNo, idxLocal) - // TODO next: pass NodeTypeInfo around - val compareAst = nativeForEachCompareAst(lineNo, iterableSource, idxLocal) - val incrementAst = nativeForEachIncrementAst(lineNo, idxLocal) - val bodyAst = nativeForEachBodyAst(stmt, idxLocal, iterableSource) - - val forAst = Ast(forNode) - .withChild(Ast(idxLocal)) - .withChild(idxInitializerAst) - .withChild(compareAst) - .withChild(incrementAst) - .withChild(bodyAst) - .withConditionEdges(forNode, compareAst.root.toList) - - tempIterableInitAsts ++ Seq(forAst) - end astsForNativeForEach - - private def iteratorLocalForForEach(lineNumber: Option[Integer]): NewLocal = - val iteratorLocalName = nextIterableName() + val iterableAssignArgs = List(Ast(iterableAssignIdentifier), iterableAst) + val iterableAssignAst = + callAst(iterableAssignNode, iterableAssignArgs) + .withRefEdge(iterableAssignIdentifier, iterableLocalNode) + + ( + NodeTypeInfo( + iterableLocalNode, + iterableLocalNode.name, + Some(iterableLocalNode.typeFullName) + ), + List(iterableLocalAst, iterableAssignAst) + ) + end iterableAssignAstsForNativeForEach + + private def nativeForEachIdxLocalNode(lineNo: Option[Integer]): NewLocal = + val idxName = nextIndexName() + val typeFullName = TypeConstants.Int + val idxLocal = NewLocal() - .name(iteratorLocalName) - .code(iteratorLocalName) - .typeFullName(TypeConstants.Iterator) - .lineNumber(lineNumber) - - private def iteratorAssignAstForForEach( - iterExpr: Expression, - iteratorLocalNode: NewLocal, - iterableType: Option[String], - lineNo: Option[Integer] - ): Ast = - val iteratorAssignNode = - newOperatorCallNode( - Operators.assignment, - code = "", - typeFullName = Some(TypeConstants.Iterator), - line = lineNo - ) - val iteratorAssignIdentifier = - identifierNode( - iterExpr, - iteratorLocalNode.name, - iteratorLocalNode.name, - iteratorLocalNode.typeFullName - ) - - val iteratorCallNode = - newCallNode( - "iterator", - iterableType, - TypeConstants.Iterator, - DispatchTypes.DYNAMIC_DISPATCH, - lineNumber = lineNo - ) - - val actualIteratorAst = - astsForExpression(iterExpr, expectedType = ExpectedType.empty).toList match - case Nil => - logger.debug(s"Could not create receiver ast for iterator $iterExpr") - None - - case ast :: Nil => Some(ast) - - case ast :: _ => - logger.debug( - s"Created multiple receiver asts for $iterExpr. Dropping all but the first." - ) - Some(ast) - - val iteratorCallAst = - callAst(iteratorCallNode, base = actualIteratorAst) - - callAst(iteratorAssignNode, List(Ast(iteratorAssignIdentifier), iteratorCallAst)) - .withRefEdge(iteratorAssignIdentifier, iteratorLocalNode) - end iteratorAssignAstForForEach - - private def hasNextCallAstForForEach( - iteratorLocalNode: NewLocal, - lineNo: Option[Integer] - ): Ast = - val iteratorHasNextCallNode = - newCallNode( - "hasNext", - Some(TypeConstants.Iterator), - TypeConstants.Boolean, - DispatchTypes.DYNAMIC_DISPATCH, - lineNumber = lineNo - ) - val iteratorHasNextCallReceiver = - newIdentifierNode(iteratorLocalNode.name, iteratorLocalNode.typeFullName) - - callAst(iteratorHasNextCallNode, base = Some(Ast(iteratorHasNextCallReceiver))) - .withRefEdge(iteratorHasNextCallReceiver, iteratorLocalNode) - - private def astForIterableForEachItemAssign( - iteratorLocalNode: NewLocal, - variableLocal: NewLocal - ): Ast = - val lineNo = variableLocal.lineNumber - val forVariableType = variableLocal.typeFullName - val varLocalAssignNode = - newOperatorCallNode( - Operators.assignment, - PropertyDefaults.Code, - Some(forVariableType), - lineNo - ) - val varLocalAssignIdentifier = - newIdentifierNode(variableLocal.name, variableLocal.typeFullName) - - val iterNextCallNode = - newCallNode( - "next", - Some(TypeConstants.Iterator), - TypeConstants.Object, - DispatchTypes.DYNAMIC_DISPATCH, - lineNumber = lineNo - ) - val iterNextCallReceiver = - newIdentifierNode(iteratorLocalNode.name, iteratorLocalNode.typeFullName) - val iterNextCallAst = - callAst(iterNextCallNode, base = Some(Ast(iterNextCallReceiver))) - .withRefEdge(iterNextCallReceiver, iteratorLocalNode) - - callAst(varLocalAssignNode, List(Ast(varLocalAssignIdentifier), iterNextCallAst)) - .withRefEdge(varLocalAssignIdentifier, variableLocal) - end astForIterableForEachItemAssign - - private def astForIterableForEach(stmt: ForEachStmt, iterableType: Option[String]): Seq[Ast] = - val lineNo = line(stmt) - - val iteratorLocalNode = iteratorLocalForForEach(lineNo) - val iteratorAssignAst = - iteratorAssignAstForForEach(stmt.getIterable, iteratorLocalNode, iterableType, lineNo) - val iteratorHasNextCallAst = hasNextCallAstForForEach(iteratorLocalNode, lineNo) - val variableLocal = variableLocalForForEachBody(stmt) - val variableAssignAst = astForIterableForEachItemAssign(iteratorLocalNode, variableLocal) - - val bodyPrefixAsts = Seq(Ast(variableLocal), variableAssignAst) - val bodyAst = stmt.getBody match - case block: BlockStmt => - astForBlockStatement(block, prefixAsts = bodyPrefixAsts) - - case bodyStmt => - val bodyBlockNode = NewBlock().lineNumber(lineNo) - val bodyStmtAsts = astsForStatement(bodyStmt) - Ast(bodyBlockNode) - .withChildren(bodyPrefixAsts) - .withChildren(bodyStmtAsts) - - val forNode = - NewControlStructure() - .controlStructureType(ControlStructureTypes.WHILE) - .code(ControlStructureTypes.FOR) - .lineNumber(lineNo) - .columnNumber(column(stmt)) - - val forAst = controlStructureAst(forNode, Some(iteratorHasNextCallAst), List(bodyAst)) - - Seq(Ast(iteratorLocalNode), iteratorAssignAst, forAst) - end astForIterableForEach - - private def astForForEach(stmt: ForEachStmt): Seq[Ast] = - scope.pushBlockScope() - - val ast = expressionReturnTypeFullName(stmt.getIterable) match - case Some(typeFullName) if typeFullName.endsWith("[]") => - astsForNativeForEach(stmt, Some(typeFullName)) - - case maybeType => - astForIterableForEach(stmt, maybeType) - - scope.popScope() - ast + .name(idxName) + .typeFullName(typeFullName) + .code(idxName) + .lineNumber(lineNo) + scope.addLocal(idxLocal) + idxLocal + + private def nativeForEachIdxInitializerAst(lineNo: Option[Integer], idxLocal: NewLocal): Ast = + val idxName = idxLocal.name + val idxInitializerCallNode = newOperatorCallNode( + Operators.assignment, + code = s"int $idxName = 0", + line = lineNo, + typeFullName = Some(TypeConstants.Int) + ) + val idxIdentifierArg = newIdentifierNode(idxName, idxLocal.typeFullName) + val zeroLiteral = + NewLiteral() + .code("0") + .typeFullName(TypeConstants.Int) + .lineNumber(lineNo) + val idxInitializerArgAsts = List(Ast(idxIdentifierArg), Ast(zeroLiteral)) + callAst(idxInitializerCallNode, idxInitializerArgAsts) + .withRefEdge(idxIdentifierArg, idxLocal) + + private def nativeForEachCompareAst( + lineNo: Option[Integer], + iterableSource: NodeTypeInfo, + idxLocal: NewLocal + ): Ast = + val idxName = idxLocal.name + + val compareNode = newOperatorCallNode( + Operators.lessThan, + code = s"$idxName < ${iterableSource.name}.length", + typeFullName = Some(TypeConstants.Boolean), + line = lineNo + ) + val comparisonIdxIdentifier = newIdentifierNode(idxName, idxLocal.typeFullName) + val comparisonFieldAccess = newOperatorCallNode( + Operators.fieldAccess, + code = s"${iterableSource.name}.length", + typeFullName = Some(TypeConstants.Int), + line = lineNo + ) + val fieldAccessIdentifier = + newIdentifierNode(iterableSource.name, iterableSource.typeFullName.getOrElse("ANY")) + val fieldAccessFieldIdentifier = newFieldIdentifierNode("length", lineNo) + val fieldAccessArgs = List(fieldAccessIdentifier, fieldAccessFieldIdentifier).map(Ast(_)) + val fieldAccessAst = callAst(comparisonFieldAccess, fieldAccessArgs) + val compareArgs = List(Ast(comparisonIdxIdentifier), fieldAccessAst) + + // TODO: This is a workaround for a crash when looping over statically imported members. Handle those properly. + val iterableSourceNode = localParamOrMemberFromNode(iterableSource) + + callAst(compareNode, compareArgs) + .withRefEdge(comparisonIdxIdentifier, idxLocal) + .withRefEdges(fieldAccessIdentifier, iterableSourceNode.toList) + end nativeForEachCompareAst + + private def nativeForEachIncrementAst(lineNo: Option[Integer], idxLocal: NewLocal): Ast = + val incrementNode = newOperatorCallNode( + Operators.postIncrement, + code = s"${idxLocal.name}++", + typeFullName = Some(TypeConstants.Int), + line = lineNo + ) + val incrementArg = newIdentifierNode(idxLocal.name, idxLocal.typeFullName) + val incrementArgAst = Ast(incrementArg) + callAst(incrementNode, List(incrementArgAst)) + .withRefEdge(incrementArg, idxLocal) + + private def variableLocalForForEachBody(stmt: ForEachStmt): NewLocal = + val lineNo = line(stmt) + // Create item local + val maybeVariable = stmt.getVariable.getVariables.asScala.toList match + case Nil => + logger.debug(s"ForEach statement has empty variable list: $filename$lineNo") + None + case variable :: Nil => Some(variable) + case variable :: _ => + logger.debug( + s"ForEach statement defines multiple variables. Dropping all but the first: $filename$lineNo" + ) + Some(variable) + + val partialLocalNode = NewLocal().lineNumber(lineNo) + + maybeVariable match + case Some(variable) => + val name = variable.getNameAsString + val typeFullName = typeInfoCalc.fullName(variable.getType).getOrElse("ANY") + val localNode = partialLocalNode + .name(name) + .code(variable.getNameAsString) + .typeFullName(typeFullName) + + scope.addLocal(localNode) + localNode + + case None => + // Returning partialLocalNode here is fine since getting to this case means everything is broken anyways :) + partialLocalNode + end variableLocalForForEachBody + + private def localParamOrMemberFromNode(nodeTypeInfo: NodeTypeInfo): Option[NewNode] = + nodeTypeInfo.node match + case localNode: NewLocal => Some(localNode) + case memberNode: NewMember => Some(memberNode) + case parameterNode: NewMethodParameterIn => Some(parameterNode) + case _ => None + private def variableAssignForNativeForEachBody( + variableLocal: NewLocal, + idxLocal: NewLocal, + iterable: NodeTypeInfo + ): Ast = + // Everything will be on the same line as the `for` statement, but this is the most useful + // solution for debugging. + val lineNo = variableLocal.lineNumber + val varAssignNode = + newOperatorCallNode( + Operators.assignment, + PropertyDefaults.Code, + Some(variableLocal.typeFullName), + lineNo + ) - private def astForSwitchStatement(stmt: SwitchStmt): Ast = - val switchNode = - NewControlStructure() - .controlStructureType(ControlStructureTypes.SWITCH) - .code(s"switch(${stmt.getSelector.toString})") + val targetNode = newIdentifierNode(variableLocal.name, variableLocal.typeFullName) - val selectorAsts = astsForExpression(stmt.getSelector, ExpectedType.empty) - val selectorNode = selectorAsts.head.root.get + val indexAccessTypeFullName = iterable.typeFullName.map(_.replaceAll(raw"\[]", "")) + val indexAccess = + newOperatorCallNode( + Operators.indexAccess, + PropertyDefaults.Code, + indexAccessTypeFullName, + lineNo + ) - val entryAsts = stmt.getEntries.asScala.flatMap(astForSwitchEntry) + val indexAccessIdentifier = + newIdentifierNode(iterable.name, iterable.typeFullName.getOrElse("ANY")) + val indexAccessIndex = newIdentifierNode(idxLocal.name, idxLocal.typeFullName) + + val indexAccessArgsAsts = List(indexAccessIdentifier, indexAccessIndex).map(Ast(_)) + val indexAccessAst = callAst(indexAccess, indexAccessArgsAsts) + + val iterableSourceNode = localParamOrMemberFromNode(iterable) + + val assignArgsAsts = List(Ast(targetNode), indexAccessAst) + callAst(varAssignNode, assignArgsAsts) + .withRefEdge(targetNode, variableLocal) + .withRefEdges(indexAccessIdentifier, iterableSourceNode.toList) + .withRefEdge(indexAccessIndex, idxLocal) + end variableAssignForNativeForEachBody + + private def nativeForEachBodyAst( + stmt: ForEachStmt, + idxLocal: NewLocal, + iterable: NodeTypeInfo + ): Ast = + val variableLocal = variableLocalForForEachBody(stmt) + val variableLocalAst = Ast(variableLocal) + val variableAssignAst = + variableAssignForNativeForEachBody(variableLocal, idxLocal, iterable) + + stmt.getBody match + case block: BlockStmt => + astForBlockStatement(block, prefixAsts = List(variableLocalAst, variableAssignAst)) + case statement => + val stmtAsts = astsForStatement(statement) + val blockNode = NewBlock().lineNumber(variableLocal.lineNumber) + Ast(blockNode) + .withChild(variableLocalAst) + .withChild(variableAssignAst) + .withChildren(stmtAsts) + end nativeForEachBodyAst + + private def astsForNativeForEach(stmt: ForEachStmt, iterableType: Option[String]): Seq[Ast] = + + // This is ugly, but for a case like `for (int x : new int[] { ... })` this creates a new LOCAL + // with the assignment `int[] $iterLocal0 = new int[] { ... }` before the FOR loop. + // TODO: Fix this + val (iterableSource: NodeTypeInfo, tempIterableInitAsts) = stmt.getIterable match + case nameExpr: NameExpr => + scope.lookupVariable(nameExpr.getNameAsString).asNodeInfoOption match + // If this is not the case, then the code is broken (iterable not in scope). + case Some(nodeTypeInfo) => (nodeTypeInfo, Nil) + case _ => iterableAssignAstsForNativeForEach(nameExpr, iterableType) + case iterableExpr => iterableAssignAstsForNativeForEach(iterableExpr, iterableType) + + val forNode = NewControlStructure() + .controlStructureType(ControlStructureTypes.FOR) + + val lineNo = line(stmt) + + val idxLocal = nativeForEachIdxLocalNode(lineNo) + val idxInitializerAst = nativeForEachIdxInitializerAst(lineNo, idxLocal) + // TODO next: pass NodeTypeInfo around + val compareAst = nativeForEachCompareAst(lineNo, iterableSource, idxLocal) + val incrementAst = nativeForEachIncrementAst(lineNo, idxLocal) + val bodyAst = nativeForEachBodyAst(stmt, idxLocal, iterableSource) + + val forAst = Ast(forNode) + .withChild(Ast(idxLocal)) + .withChild(idxInitializerAst) + .withChild(compareAst) + .withChild(incrementAst) + .withChild(bodyAst) + .withConditionEdges(forNode, compareAst.root.toList) + + tempIterableInitAsts ++ Seq(forAst) + end astsForNativeForEach + + private def iteratorLocalForForEach(lineNumber: Option[Integer]): NewLocal = + val iteratorLocalName = nextIterableName() + NewLocal() + .name(iteratorLocalName) + .code(iteratorLocalName) + .typeFullName(TypeConstants.Iterator) + .lineNumber(lineNumber) + + private def iteratorAssignAstForForEach( + iterExpr: Expression, + iteratorLocalNode: NewLocal, + iterableType: Option[String], + lineNo: Option[Integer] + ): Ast = + val iteratorAssignNode = + newOperatorCallNode( + Operators.assignment, + code = "", + typeFullName = Some(TypeConstants.Iterator), + line = lineNo + ) + val iteratorAssignIdentifier = + identifierNode( + iterExpr, + iteratorLocalNode.name, + iteratorLocalNode.name, + iteratorLocalNode.typeFullName + ) - val switchBodyAst = Ast(NewBlock()).withChildren(entryAsts) + val iteratorCallNode = + newCallNode( + "iterator", + iterableType, + TypeConstants.Iterator, + DispatchTypes.DYNAMIC_DISPATCH, + lineNumber = lineNo + ) - Ast(switchNode) - .withChildren(selectorAsts) - .withChild(switchBodyAst) - .withConditionEdge(switchNode, selectorNode) + val actualIteratorAst = + astsForExpression(iterExpr, expectedType = ExpectedType.empty).toList match + case Nil => + logger.debug(s"Could not create receiver ast for iterator $iterExpr") + None - private def astForSynchronizedStatement(stmt: SynchronizedStmt): Ast = - val parentNode = - NewBlock() - .lineNumber(line(stmt)) - .columnNumber(column(stmt)) + case ast :: Nil => Some(ast) - val modifier = Ast(newModifierNode("SYNCHRONIZED")) + case ast :: _ => + logger.debug( + s"Created multiple receiver asts for $iterExpr. Dropping all but the first." + ) + Some(ast) + + val iteratorCallAst = + callAst(iteratorCallNode, base = actualIteratorAst) + + callAst(iteratorAssignNode, List(Ast(iteratorAssignIdentifier), iteratorCallAst)) + .withRefEdge(iteratorAssignIdentifier, iteratorLocalNode) + end iteratorAssignAstForForEach + + private def hasNextCallAstForForEach( + iteratorLocalNode: NewLocal, + lineNo: Option[Integer] + ): Ast = + val iteratorHasNextCallNode = + newCallNode( + "hasNext", + Some(TypeConstants.Iterator), + TypeConstants.Boolean, + DispatchTypes.DYNAMIC_DISPATCH, + lineNumber = lineNo + ) + val iteratorHasNextCallReceiver = + newIdentifierNode(iteratorLocalNode.name, iteratorLocalNode.typeFullName) + + callAst(iteratorHasNextCallNode, base = Some(Ast(iteratorHasNextCallReceiver))) + .withRefEdge(iteratorHasNextCallReceiver, iteratorLocalNode) + + private def astForIterableForEachItemAssign( + iteratorLocalNode: NewLocal, + variableLocal: NewLocal + ): Ast = + val lineNo = variableLocal.lineNumber + val forVariableType = variableLocal.typeFullName + val varLocalAssignNode = + newOperatorCallNode( + Operators.assignment, + PropertyDefaults.Code, + Some(forVariableType), + lineNo + ) + val varLocalAssignIdentifier = + newIdentifierNode(variableLocal.name, variableLocal.typeFullName) + + val iterNextCallNode = + newCallNode( + "next", + Some(TypeConstants.Iterator), + TypeConstants.Object, + DispatchTypes.DYNAMIC_DISPATCH, + lineNumber = lineNo + ) + val iterNextCallReceiver = + newIdentifierNode(iteratorLocalNode.name, iteratorLocalNode.typeFullName) + val iterNextCallAst = + callAst(iterNextCallNode, base = Some(Ast(iterNextCallReceiver))) + .withRefEdge(iterNextCallReceiver, iteratorLocalNode) + + callAst(varLocalAssignNode, List(Ast(varLocalAssignIdentifier), iterNextCallAst)) + .withRefEdge(varLocalAssignIdentifier, variableLocal) + end astForIterableForEachItemAssign + + private def astForIterableForEach(stmt: ForEachStmt, iterableType: Option[String]): Seq[Ast] = + val lineNo = line(stmt) + + val iteratorLocalNode = iteratorLocalForForEach(lineNo) + val iteratorAssignAst = + iteratorAssignAstForForEach(stmt.getIterable, iteratorLocalNode, iterableType, lineNo) + val iteratorHasNextCallAst = hasNextCallAstForForEach(iteratorLocalNode, lineNo) + val variableLocal = variableLocalForForEachBody(stmt) + val variableAssignAst = astForIterableForEachItemAssign(iteratorLocalNode, variableLocal) + + val bodyPrefixAsts = Seq(Ast(variableLocal), variableAssignAst) + val bodyAst = stmt.getBody match + case block: BlockStmt => + astForBlockStatement(block, prefixAsts = bodyPrefixAsts) + + case bodyStmt => + val bodyBlockNode = NewBlock().lineNumber(lineNo) + val bodyStmtAsts = astsForStatement(bodyStmt) + Ast(bodyBlockNode) + .withChildren(bodyPrefixAsts) + .withChildren(bodyStmtAsts) + + val forNode = + NewControlStructure() + .controlStructureType(ControlStructureTypes.WHILE) + .code(ControlStructureTypes.FOR) + .lineNumber(lineNo) + .columnNumber(column(stmt)) - val exprAsts = astsForExpression(stmt.getExpression, ExpectedType.empty) - val bodyAst = astForBlockStatement(stmt.getBody) + val forAst = controlStructureAst(forNode, Some(iteratorHasNextCallAst), List(bodyAst)) - Ast(parentNode) - .withChild(modifier) - .withChildren(exprAsts) - .withChild(bodyAst) + Seq(Ast(iteratorLocalNode), iteratorAssignAst, forAst) + end astForIterableForEach - private def astsForSwitchCases(entry: SwitchEntry): Seq[Ast] = - entry.getLabels.asScala.toList match - case Nil => - val target = NewJumpTarget() - .name("default") - .code("default") - Seq(Ast(target)) + private def astForForEach(stmt: ForEachStmt): Seq[Ast] = + scope.pushBlockScope() - case labels => - labels.flatMap { label => - val jumpTarget = NewJumpTarget() - .name("case") - .code(label.toString) - val labelAsts = astsForExpression(label, ExpectedType.empty).toList + val ast = expressionReturnTypeFullName(stmt.getIterable) match + case Some(typeFullName) if typeFullName.endsWith("[]") => + astsForNativeForEach(stmt, Some(typeFullName)) - Ast(jumpTarget) :: labelAsts - } + case maybeType => + astForIterableForEach(stmt, maybeType) - private def astForSwitchEntry(entry: SwitchEntry): Seq[Ast] = - val labelAsts = astsForSwitchCases(entry) + scope.popScope() + ast - val statementAsts = entry.getStatements.asScala.flatMap(astsForStatement) + private def astForSwitchStatement(stmt: SwitchStmt): Ast = + val switchNode = + NewControlStructure() + .controlStructureType(ControlStructureTypes.SWITCH) + .code(s"switch(${stmt.getSelector.toString})") - labelAsts ++ statementAsts + val selectorAsts = astsForExpression(stmt.getSelector, ExpectedType.empty) + val selectorNode = selectorAsts.head.root.get - private def astForAssertStatement(stmt: AssertStmt): Ast = - val callNode = NewCall() - .name("assert") - .methodFullName("assert") - .dispatchType(DispatchTypes.STATIC_DISPATCH) - .code(stmt.toString) - .lineNumber(line(stmt)) - .columnNumber(column(stmt)) + val entryAsts = stmt.getEntries.asScala.flatMap(astForSwitchEntry) - val args = astsForExpression(stmt.getCheck, ExpectedType.Boolean) - callAst(callNode, args) + val switchBodyAst = Ast(NewBlock()).withChildren(entryAsts) - private def astForBlockStatement( - stmt: BlockStmt, - codeStr: String = "", - prefixAsts: Seq[Ast] = Seq.empty - ): Ast = + Ast(switchNode) + .withChildren(selectorAsts) + .withChild(switchBodyAst) + .withConditionEdge(switchNode, selectorNode) - val block = NewBlock() - .code(codeStr) + private def astForSynchronizedStatement(stmt: SynchronizedStmt): Ast = + val parentNode = + NewBlock() .lineNumber(line(stmt)) .columnNumber(column(stmt)) - scope.pushBlockScope() + val modifier = Ast(newModifierNode("SYNCHRONIZED")) - val stmtAsts = stmt.getStatements.asScala.flatMap(astsForStatement) + val exprAsts = astsForExpression(stmt.getExpression, ExpectedType.empty) + val bodyAst = astForBlockStatement(stmt.getBody) - scope.popScope() - Ast(block) - .withChildren(prefixAsts) - .withChildren(stmtAsts) - end astForBlockStatement - - private def astForReturnNode(ret: ReturnStmt): Ast = - val returnNode = NewReturn() - .lineNumber(line(ret)) - .columnNumber(column(ret)) - .code(ret.toString) - if ret.getExpression.isPresent then - val expectedType = scope.enclosingMethodReturnType.getOrElse(ExpectedType.empty) - val exprAsts = astsForExpression(ret.getExpression.get(), expectedType) - returnAst(returnNode, exprAsts) - else - Ast(returnNode) + Ast(parentNode) + .withChild(modifier) + .withChildren(exprAsts) + .withChild(bodyAst) - private def astForUnaryExpr(expr: UnaryExpr, expectedType: ExpectedType): Ast = - val operatorName = expr.getOperator match - case UnaryExpr.Operator.LOGICAL_COMPLEMENT => Operators.logicalNot - case UnaryExpr.Operator.POSTFIX_DECREMENT => Operators.postDecrement - case UnaryExpr.Operator.POSTFIX_INCREMENT => Operators.postIncrement - case UnaryExpr.Operator.PREFIX_DECREMENT => Operators.preDecrement - case UnaryExpr.Operator.PREFIX_INCREMENT => Operators.preIncrement - case UnaryExpr.Operator.BITWISE_COMPLEMENT => Operators.not - case UnaryExpr.Operator.PLUS => Operators.plus - case UnaryExpr.Operator.MINUS => Operators.minus + private def astsForSwitchCases(entry: SwitchEntry): Seq[Ast] = + entry.getLabels.asScala.toList match + case Nil => + val target = NewJumpTarget() + .name("default") + .code("default") + Seq(Ast(target)) - val argsAsts = astsForExpression(expr.getExpression, expectedType) - - val typeFullName = - expressionReturnTypeFullName(expr) - .orElse(argsAsts.headOption.flatMap(_.rootType)) - .orElse(expectedType.fullName) - .getOrElse(TypeConstants.Any) + case labels => + labels.flatMap { label => + val jumpTarget = NewJumpTarget() + .name("case") + .code(label.toString) + val labelAsts = astsForExpression(label, ExpectedType.empty).toList - val callNode = newOperatorCallNode( - operatorName, - code = expr.toString, - typeFullName = Some(typeFullName), - line = line(expr), - column = column(expr) - ) + Ast(jumpTarget) :: labelAsts + } - callAst(callNode, argsAsts) - end astForUnaryExpr + private def astForSwitchEntry(entry: SwitchEntry): Seq[Ast] = + val labelAsts = astsForSwitchCases(entry) + + val statementAsts = entry.getStatements.asScala.flatMap(astsForStatement) + + labelAsts ++ statementAsts + + private def astForAssertStatement(stmt: AssertStmt): Ast = + val callNode = NewCall() + .name("assert") + .methodFullName("assert") + .dispatchType(DispatchTypes.STATIC_DISPATCH) + .code(stmt.toString) + .lineNumber(line(stmt)) + .columnNumber(column(stmt)) + + val args = astsForExpression(stmt.getCheck, ExpectedType.Boolean) + callAst(callNode, args) + + private def astForBlockStatement( + stmt: BlockStmt, + codeStr: String = "", + prefixAsts: Seq[Ast] = Seq.empty + ): Ast = + + val block = NewBlock() + .code(codeStr) + .lineNumber(line(stmt)) + .columnNumber(column(stmt)) + + scope.pushBlockScope() + + val stmtAsts = stmt.getStatements.asScala.flatMap(astsForStatement) + + scope.popScope() + Ast(block) + .withChildren(prefixAsts) + .withChildren(stmtAsts) + end astForBlockStatement + + private def astForReturnNode(ret: ReturnStmt): Ast = + val returnNode = NewReturn() + .lineNumber(line(ret)) + .columnNumber(column(ret)) + .code(ret.toString) + if ret.getExpression.isPresent then + val expectedType = scope.enclosingMethodReturnType.getOrElse(ExpectedType.empty) + val exprAsts = astsForExpression(ret.getExpression.get(), expectedType) + returnAst(returnNode, exprAsts) + else + Ast(returnNode) + + private def astForUnaryExpr(expr: UnaryExpr, expectedType: ExpectedType): Ast = + val operatorName = expr.getOperator match + case UnaryExpr.Operator.LOGICAL_COMPLEMENT => Operators.logicalNot + case UnaryExpr.Operator.POSTFIX_DECREMENT => Operators.postDecrement + case UnaryExpr.Operator.POSTFIX_INCREMENT => Operators.postIncrement + case UnaryExpr.Operator.PREFIX_DECREMENT => Operators.preDecrement + case UnaryExpr.Operator.PREFIX_INCREMENT => Operators.preIncrement + case UnaryExpr.Operator.BITWISE_COMPLEMENT => Operators.not + case UnaryExpr.Operator.PLUS => Operators.plus + case UnaryExpr.Operator.MINUS => Operators.minus + + val argsAsts = astsForExpression(expr.getExpression, expectedType) + + val typeFullName = + expressionReturnTypeFullName(expr) + .orElse(argsAsts.headOption.flatMap(_.rootType)) + .orElse(expectedType.fullName) + .getOrElse(TypeConstants.Any) + + val callNode = newOperatorCallNode( + operatorName, + code = expr.toString, + typeFullName = Some(typeFullName), + line = line(expr), + column = column(expr) + ) + + callAst(callNode, argsAsts) + end astForUnaryExpr + + private def astForArrayAccessExpr(expr: ArrayAccessExpr, expectedType: ExpectedType): Ast = + val typeFullName = + expressionReturnTypeFullName(expr) + .orElse(expectedType.fullName) + .getOrElse(TypeConstants.Any) + val callNode = newOperatorCallNode( + Operators.indexAccess, + code = expr.toString, + typeFullName = Some(typeFullName), + line = line(expr), + column = column(expr) + ) + + val arrayExpectedType = expectedType.copy(fullName = expectedType.fullName.map(_ ++ "[]")) + val nameAst = astsForExpression(expr.getName, arrayExpectedType) + val indexAst = astsForExpression(expr.getIndex, ExpectedType.Int) + val args = nameAst ++ indexAst + callAst(callNode, args) + + private def astForArrayCreationExpr(expr: ArrayCreationExpr, expectedType: ExpectedType): Ast = + val elementType = tryWithSafeStackOverflow(expr.getElementType.resolve()).map(elementType => + ExpectedType(typeInfoCalc.fullName(elementType).map(_ ++ "[]"), Option(elementType)) + ) + val maybeInitializerAst = + expr.getInitializer.toScala.map(astForArrayInitializerExpr( + _, + elementType.getOrElse(expectedType) + )) - private def astForArrayAccessExpr(expr: ArrayAccessExpr, expectedType: ExpectedType): Ast = - val typeFullName = - expressionReturnTypeFullName(expr) - .orElse(expectedType.fullName) - .getOrElse(TypeConstants.Any) + maybeInitializerAst.flatMap(_.root) match + case Some(initializerRoot: NewCall) => initializerRoot.code(expr.toString) + case _ => // This should never happen + maybeInitializerAst.getOrElse { + val typeFullName = expressionReturnTypeFullName(expr).orElse( + expectedType.fullName + ).getOrElse(TypeConstants.Any) val callNode = newOperatorCallNode( - Operators.indexAccess, + Operators.alloc, code = expr.toString, - typeFullName = Some(typeFullName), - line = line(expr), - column = column(expr) + typeFullName = Some(typeFullName) ) - - val arrayExpectedType = expectedType.copy(fullName = expectedType.fullName.map(_ ++ "[]")) - val nameAst = astsForExpression(expr.getName, arrayExpectedType) - val indexAst = astsForExpression(expr.getIndex, ExpectedType.Int) - val args = nameAst ++ indexAst - callAst(callNode, args) - - private def astForArrayCreationExpr(expr: ArrayCreationExpr, expectedType: ExpectedType): Ast = - val maybeInitializerAst = - expr.getInitializer.toScala.map(astForArrayInitializerExpr(_, expectedType)) - - maybeInitializerAst.flatMap(_.root) match - case Some(initializerRoot: NewCall) => initializerRoot.code(expr.toString) - case _ => // This should never happen - maybeInitializerAst.getOrElse { - val typeFullName = expressionReturnTypeFullName(expr).orElse( - expectedType.fullName - ).getOrElse(TypeConstants.Any) - val callNode = newOperatorCallNode( - Operators.alloc, - code = expr.toString, - typeFullName = Some(typeFullName) - ) - val levelAsts = expr.getLevels.asScala.flatMap { lvl => - lvl.getDimension.toScala match - case Some(dimension) => astsForExpression(dimension, ExpectedType.Int) - - case None => Seq.empty - }.toSeq - callAst(callNode, levelAsts) - } - end astForArrayCreationExpr - - private def astForArrayInitializerExpr( - expr: ArrayInitializerExpr, - expectedType: ExpectedType - ): Ast = - val typeFullName = - expressionReturnTypeFullName(expr) - .orElse(expectedType.fullName) - .getOrElse(TypeConstants.Any) - val callNode = newOperatorCallNode( - Operators.arrayInitializer, - code = expr.toString, - typeFullName = Some(typeFullName), - line = line(expr), - column = column(expr) + val levelAsts = expr.getLevels.asScala.flatMap { lvl => + lvl.getDimension.toScala match + case Some(dimension) => astsForExpression(dimension, ExpectedType.Int) + + case None => Seq.empty + }.toSeq + callAst(callNode, levelAsts) + } + end astForArrayCreationExpr + + private def astForArrayInitializerExpr( + expr: ArrayInitializerExpr, + expectedType: ExpectedType + ): Ast = + val typeFullName = expectedType.fullName + .map(typeInfoCalc.registerType) + .getOrElse(TypeConstants.Any) + val callNode = newOperatorCallNode( + Operators.arrayInitializer, + code = expr.toString, + typeFullName = Some(typeFullName), + line = line(expr), + column = column(expr) + ) + + val MAX_INITIALIZERS = 1000 + + val expectedValueType = expr.getValues.asScala.headOption.map { value => + // typeName and resolvedType may represent different types since typeName can fall + // back to known information or primitive types. While this certainly isn't ideal, + // it shouldn't cause issues since resolvedType is only used where the extra type + // information not available in typeName is necessary. + val typeName = expressionReturnTypeFullName(value) + val resolvedType = tryWithSafeStackOverflow(value.calculateResolvedType()).toOption + ExpectedType(typeName, resolvedType) + } + val args = expr.getValues.asScala + .slice(0, MAX_INITIALIZERS) + .flatMap(astsForExpression(_, expectedValueType.getOrElse(ExpectedType.empty))) + .toSeq + + val ast = callAst(callNode, args) + + if expr.getValues.size() > MAX_INITIALIZERS then + val placeholder = NewLiteral() + .typeFullName(TypeConstants.Any) + .code("") + .lineNumber(line(expr)) + .columnNumber(column(expr)) + ast.withChild(Ast(placeholder)).withArgEdge(callNode, placeholder) + else + ast + end astForArrayInitializerExpr + + def astForBinaryExpr(expr: BinaryExpr, expectedType: ExpectedType): Ast = + val operatorName = expr.getOperator match + case BinaryExpr.Operator.OR => Operators.logicalOr + case BinaryExpr.Operator.AND => Operators.logicalAnd + case BinaryExpr.Operator.BINARY_OR => Operators.or + case BinaryExpr.Operator.BINARY_AND => Operators.and + case BinaryExpr.Operator.DIVIDE => Operators.division + case BinaryExpr.Operator.EQUALS => Operators.equals + case BinaryExpr.Operator.GREATER => Operators.greaterThan + case BinaryExpr.Operator.GREATER_EQUALS => Operators.greaterEqualsThan + case BinaryExpr.Operator.LESS => Operators.lessThan + case BinaryExpr.Operator.LESS_EQUALS => Operators.lessEqualsThan + case BinaryExpr.Operator.LEFT_SHIFT => Operators.shiftLeft + case BinaryExpr.Operator.SIGNED_RIGHT_SHIFT => Operators.logicalShiftRight + case BinaryExpr.Operator.UNSIGNED_RIGHT_SHIFT => Operators.arithmeticShiftRight + case BinaryExpr.Operator.XOR => Operators.xor + case BinaryExpr.Operator.NOT_EQUALS => Operators.notEquals + case BinaryExpr.Operator.PLUS => Operators.addition + case BinaryExpr.Operator.MINUS => Operators.subtraction + case BinaryExpr.Operator.MULTIPLY => Operators.multiplication + case BinaryExpr.Operator.REMAINDER => Operators.modulo + + val args = + astsForExpression(expr.getLeft, expectedType) ++ astsForExpression( + expr.getRight, + expectedType ) - val MAX_INITIALIZERS = 1000 - - val expectedValueType = expr.getValues.asScala.headOption.map { value => - // typeName and resolvedType may represent different types since typeName can fall - // back to known information or primitive types. While this certainly isn't ideal, - // it shouldn't cause issues since resolvedType is only used where the extra type - // information not available in typeName is necessary. - val typeName = expressionReturnTypeFullName(value) - val resolvedType = tryWithSafeStackOverflow(value.calculateResolvedType()).toOption - ExpectedType(typeName, resolvedType) + val typeFullName = + expressionReturnTypeFullName(expr) + .orElse(args.headOption.flatMap(_.rootType)) + .orElse(args.lastOption.flatMap(_.rootType)) + .orElse(expectedType.fullName) + .getOrElse(TypeConstants.Any) + + val callNode = newOperatorCallNode( + operatorName, + code = expr.toString, + typeFullName = Some(typeFullName), + line = line(expr), + column = column(expr) + ) + + callAst(callNode, args) + end astForBinaryExpr + + private def astForCastExpr(expr: CastExpr, expectedType: ExpectedType): Ast = + val typeFullName = + typeInfoCalc + .fullName(expr.getType) + .orElse(expectedType.fullName) + .getOrElse(TypeConstants.Any) + + val callNode = newOperatorCallNode( + Operators.cast, + code = expr.toString, + typeFullName = Some(typeFullName), + line = line(expr), + column = column(expr) + ) + + val typeNode = NewTypeRef() + .code(expr.getType.toString) + .lineNumber(line(expr)) + .columnNumber(column(expr)) + .typeFullName(typeFullName) + val typeAst = Ast(typeNode) + + val exprAst = astsForExpression(expr.getExpression, ExpectedType.empty) + + callAst(callNode, Seq(typeAst) ++ exprAst) + end astForCastExpr + + private def astsForAssignExpr(expr: AssignExpr, expectedExprType: ExpectedType): Seq[Ast] = + val operatorName = expr.getOperator match + case Operator.ASSIGN => Operators.assignment + case Operator.PLUS => Operators.assignmentPlus + case Operator.MINUS => Operators.assignmentMinus + case Operator.MULTIPLY => Operators.assignmentMultiplication + case Operator.DIVIDE => Operators.assignmentDivision + case Operator.BINARY_AND => Operators.assignmentAnd + case Operator.BINARY_OR => Operators.assignmentOr + case Operator.XOR => Operators.assignmentXor + case Operator.REMAINDER => Operators.assignmentModulo + case Operator.LEFT_SHIFT => Operators.assignmentShiftLeft + case Operator.SIGNED_RIGHT_SHIFT => Operators.assignmentArithmeticShiftRight + case Operator.UNSIGNED_RIGHT_SHIFT => Operators.assignmentLogicalShiftRight + + val maybeResolvedType = Try(expr.getTarget.calculateResolvedType()).toOption + val expectedType = maybeResolvedType + .map { resolvedType => + ExpectedType(typeInfoCalc.fullName(resolvedType), Some(resolvedType)) } - val args = expr.getValues.asScala - .slice(0, MAX_INITIALIZERS) - .flatMap(astsForExpression(_, expectedValueType.getOrElse(ExpectedType.empty))) - .toSeq - - val ast = callAst(callNode, args) + .getOrElse(expectedExprType) // resolved target type should be more accurate + val targetAst = astsForExpression(expr.getTarget, expectedType) + val argsAsts = astsForExpression(expr.getValue, expectedType) + val valueType = argsAsts.headOption.flatMap(_.rootType) + + val typeFullName = + targetAst.headOption + .flatMap(_.rootType) + .orElse(valueType) + .orElse(expectedType.fullName) + .getOrElse(TypeConstants.Any) - if expr.getValues.size() > MAX_INITIALIZERS then - val placeholder = NewLiteral() - .typeFullName(TypeConstants.Any) - .code("") - .lineNumber(line(expr)) - .columnNumber(column(expr)) - ast.withChild(Ast(placeholder)).withArgEdge(callNode, placeholder) - else - ast - end astForArrayInitializerExpr - - def astForBinaryExpr(expr: BinaryExpr, expectedType: ExpectedType): Ast = - val operatorName = expr.getOperator match - case BinaryExpr.Operator.OR => Operators.logicalOr - case BinaryExpr.Operator.AND => Operators.logicalAnd - case BinaryExpr.Operator.BINARY_OR => Operators.or - case BinaryExpr.Operator.BINARY_AND => Operators.and - case BinaryExpr.Operator.DIVIDE => Operators.division - case BinaryExpr.Operator.EQUALS => Operators.equals - case BinaryExpr.Operator.GREATER => Operators.greaterThan - case BinaryExpr.Operator.GREATER_EQUALS => Operators.greaterEqualsThan - case BinaryExpr.Operator.LESS => Operators.lessThan - case BinaryExpr.Operator.LESS_EQUALS => Operators.lessEqualsThan - case BinaryExpr.Operator.LEFT_SHIFT => Operators.shiftLeft - case BinaryExpr.Operator.SIGNED_RIGHT_SHIFT => Operators.logicalShiftRight - case BinaryExpr.Operator.UNSIGNED_RIGHT_SHIFT => Operators.arithmeticShiftRight - case BinaryExpr.Operator.XOR => Operators.xor - case BinaryExpr.Operator.NOT_EQUALS => Operators.notEquals - case BinaryExpr.Operator.PLUS => Operators.addition - case BinaryExpr.Operator.MINUS => Operators.subtraction - case BinaryExpr.Operator.MULTIPLY => Operators.multiplication - case BinaryExpr.Operator.REMAINDER => Operators.modulo - - val args = - astsForExpression(expr.getLeft, expectedType) ++ astsForExpression( - expr.getRight, - expectedType - ) + val code = + s"${targetAst.rootCodeOrEmpty} ${expr.getOperator.asString} ${argsAsts.rootCodeOrEmpty}" - val typeFullName = - expressionReturnTypeFullName(expr) - .orElse(args.headOption.flatMap(_.rootType)) - .orElse(args.lastOption.flatMap(_.rootType)) - .orElse(expectedType.fullName) - .getOrElse(TypeConstants.Any) + val callNode = + newOperatorCallNode(operatorName, code, Some(typeFullName), line(expr), column(expr)) - val callNode = newOperatorCallNode( - operatorName, - code = expr.toString, - typeFullName = Some(typeFullName), - line = line(expr), - column = column(expr) + if partialConstructorQueue.isEmpty then + val assignAst = callAst(callNode, targetAst ++ argsAsts) + Seq(assignAst) + else + if partialConstructorQueue.size > 1 then + logger.debug( + "BUG: Received multiple partial constructors from assignment. Dropping all but the first." ) - - callAst(callNode, args) - end astForBinaryExpr - - private def astForCastExpr(expr: CastExpr, expectedType: ExpectedType): Ast = + val partialConstructor = partialConstructorQueue.head + partialConstructorQueue.clear() + + targetAst.flatMap(_.root).toList match + case List(identifier: NewIdentifier) => + // In this case we have a simple assign. No block needed. + // e.g. Foo f = new Foo(); + val initAst = + completeInitForConstructor(partialConstructor, Ast(identifier.copy)) + Seq(callAst(callNode, targetAst ++ argsAsts), initAst) + + case _ => + // In this case the left hand side is more complex than an identifier, so + // we need to contain the constructor in a block. + // e.g. items[10] = new Foo(); + val valueAst = partialConstructor.blockAst + Seq(callAst(callNode, targetAst ++ Seq(valueAst))) + end if + end astsForAssignExpr + + private def localsForVarDecl(varDecl: VariableDeclarationExpr): List[NewLocal] = + varDecl.getVariables.asScala.map { variable => + val name = variable.getName.toString val typeFullName = - typeInfoCalc - .fullName(expr.getType) - .orElse(expectedType.fullName) + tryWithSafeStackOverflow(typeInfoCalc.fullName(variable.getType)).toOption.flatten + .orElse(scope.lookupType(variable.getTypeAsString)) .getOrElse(TypeConstants.Any) - - val callNode = newOperatorCallNode( - Operators.cast, - code = expr.toString, - typeFullName = Some(typeFullName), - line = line(expr), - column = column(expr) - ) - - val typeNode = NewTypeRef() - .code(expr.getType.toString) - .lineNumber(line(expr)) - .columnNumber(column(expr)) + val code = s"${variable.getType} $name" + NewLocal() + .name(name) + .code(code) .typeFullName(typeFullName) - val typeAst = Ast(typeNode) - - val exprAst = astsForExpression(expr.getExpression, ExpectedType.empty) - - callAst(callNode, Seq(typeAst) ++ exprAst) - end astForCastExpr - - private def astsForAssignExpr(expr: AssignExpr, expectedExprType: ExpectedType): Seq[Ast] = - val operatorName = expr.getOperator match - case Operator.ASSIGN => Operators.assignment - case Operator.PLUS => Operators.assignmentPlus - case Operator.MINUS => Operators.assignmentMinus - case Operator.MULTIPLY => Operators.assignmentMultiplication - case Operator.DIVIDE => Operators.assignmentDivision - case Operator.BINARY_AND => Operators.assignmentAnd - case Operator.BINARY_OR => Operators.assignmentOr - case Operator.XOR => Operators.assignmentXor - case Operator.REMAINDER => Operators.assignmentModulo - case Operator.LEFT_SHIFT => Operators.assignmentShiftLeft - case Operator.SIGNED_RIGHT_SHIFT => Operators.assignmentArithmeticShiftRight - case Operator.UNSIGNED_RIGHT_SHIFT => Operators.assignmentLogicalShiftRight - - val maybeResolvedType = Try(expr.getTarget.calculateResolvedType()).toOption - val expectedType = maybeResolvedType - .map { resolvedType => - ExpectedType(typeInfoCalc.fullName(resolvedType), Some(resolvedType)) - } - .getOrElse(expectedExprType) // resolved target type should be more accurate - val targetAst = astsForExpression(expr.getTarget, expectedType) - val argsAsts = astsForExpression(expr.getValue, expectedType) - val valueType = argsAsts.headOption.flatMap(_.rootType) - - val typeFullName = - targetAst.headOption - .flatMap(_.rootType) - .orElse(valueType) - .orElse(expectedType.fullName) - .getOrElse(TypeConstants.Any) + .lineNumber(line(varDecl)) + .columnNumber(column(varDecl)) + }.toList - val code = - s"${targetAst.rootCodeOrEmpty} ${expr.getOperator.asString} ${argsAsts.rootCodeOrEmpty}" - - val callNode = - newOperatorCallNode(operatorName, code, Some(typeFullName), line(expr), column(expr)) - - if partialConstructorQueue.isEmpty then - val assignAst = callAst(callNode, targetAst ++ argsAsts) - Seq(assignAst) - else - if partialConstructorQueue.size > 1 then - logger.debug( - "BUG: Received multiple partial constructors from assignment. Dropping all but the first." - ) - val partialConstructor = partialConstructorQueue.head - partialConstructorQueue.clear() - - targetAst.flatMap(_.root).toList match - case List(identifier: NewIdentifier) => - // In this case we have a simple assign. No block needed. - // e.g. Foo f = new Foo(); - val initAst = - completeInitForConstructor(partialConstructor, Ast(identifier.copy)) - Seq(callAst(callNode, targetAst ++ argsAsts), initAst) - - case _ => - // In this case the left hand side is more complex than an identifier, so - // we need to contain the constructor in a block. - // e.g. items[10] = new Foo(); - val valueAst = partialConstructor.blockAst - Seq(callAst(callNode, targetAst ++ Seq(valueAst))) - end if - end astsForAssignExpr - - private def localsForVarDecl(varDecl: VariableDeclarationExpr): List[NewLocal] = - varDecl.getVariables.asScala.map { variable => - val name = variable.getName.toString - val typeFullName = - tryWithSafeStackOverflow(typeInfoCalc.fullName(variable.getType)).toOption.flatten - .orElse(scope.lookupType(variable.getTypeAsString)) - .getOrElse(TypeConstants.Any) - val code = s"${variable.getType} $name" - NewLocal() - .name(name) - .code(code) - .typeFullName(typeFullName) - .lineNumber(line(varDecl)) - .columnNumber(column(varDecl)) - }.toList + private def copyAstForVarDeclInit(targetAst: Ast): Ast = + targetAst.root match + case Some(identifier: NewIdentifier) => Ast(identifier.copy) - private def copyAstForVarDeclInit(targetAst: Ast): Ast = - targetAst.root match - case Some(identifier: NewIdentifier) => Ast(identifier.copy) + case Some(fieldAccess: NewCall) if fieldAccess.name == Operators.fieldAccess => + val maybeIdentifier = targetAst.nodes.collectFirst { + case node if node.isInstanceOf[NewIdentifier] => node + } + val maybeField = targetAst.nodes.collectFirst { + case node if node.isInstanceOf[NewFieldIdentifier] => node + } - case Some(fieldAccess: NewCall) if fieldAccess.name == Operators.fieldAccess => - val maybeIdentifier = targetAst.nodes.collectFirst { - case node if node.isInstanceOf[NewIdentifier] => node - } - val maybeField = targetAst.nodes.collectFirst { - case node if node.isInstanceOf[NewFieldIdentifier] => node - } + (maybeIdentifier, maybeField) match + case (Some(identifier), Some(fieldIdentifier)) => + val args = List(identifier, fieldIdentifier).map(node => Ast(node.copy)) + callAst(fieldAccess.copy, args) - (maybeIdentifier, maybeField) match - case (Some(identifier), Some(fieldIdentifier)) => - val args = List(identifier, fieldIdentifier).map(node => Ast(node.copy)) - callAst(fieldAccess.copy, args) - - case _ => - logger.debug( - s"Attempting to copy field access without required children: ${fieldAccess.code}" - ) - Ast() - - case Some(root) => - logger.debug(s"Attempting to copy unhandled root type for var decl init: $root") - Ast() - - case None => - Ast() - - private def assignmentsForVarDecl( - variables: Iterable[VariableDeclarator], - lineNumber: Option[Integer], - columnNumber: Option[Integer] - ): Seq[Ast] = - val variablesWithInitializers = - variables.filter(_.getInitializer.toScala.isDefined) - val assignments = variablesWithInitializers.flatMap { variable => - val name = variable.getName.toString - val initializer = variable.getInitializer.toScala.get // Won't crash because of filter - val initializerTypeFullName = - variable.getInitializer.toScala.flatMap(expressionReturnTypeFullName) - val javaParserVarType = variable.getTypeAsString - val variableTypeFullName = - tryWithSafeStackOverflow(typeInfoCalc.fullName(variable.getType)).toOption.flatten - // TODO: Surely the variable being declared can't already be in scope? - .orElse(scope.lookupVariable(name).typeFullName) - .orElse(scope.lookupType(javaParserVarType)) - - val typeFullName = - variableTypeFullName.orElse(initializerTypeFullName) - - // Need the actual resolvedType here for when the RHS is a lambda expression. - val resolvedExpectedType = - tryWithSafeStackOverflow(symbolSolver.toResolvedType( - variable.getType, - classOf[ResolvedType] - )).toOption - val initializerAsts = - astsForExpression(initializer, ExpectedType(typeFullName, resolvedExpectedType)) - - val typeName = typeFullName - .map(TypeNodePass.fullToShortName) - .getOrElse(guessTypeFullName(variable.getTypeAsString)) - val code = s"$typeName $name = ${initializerAsts.rootCodeOrEmpty}" - - val callNode = newOperatorCallNode( - Operators.assignment, - code, + case _ => + logger.debug( + s"Attempting to copy field access without required children: ${fieldAccess.code}" + ) + Ast() + + case Some(root) => + logger.debug(s"Attempting to copy unhandled root type for var decl init: $root") + Ast() + + case None => + Ast() + + private def assignmentsForVarDecl( + variables: Iterable[VariableDeclarator], + lineNumber: Option[Integer], + columnNumber: Option[Integer] + ): Seq[Ast] = + val variablesWithInitializers = + variables.filter(_.getInitializer.toScala.isDefined) + val assignments = variablesWithInitializers.flatMap { variable => + val name = variable.getName.toString + val initializer = variable.getInitializer.toScala.get // Won't crash because of filter + val initializerTypeFullName = + variable.getInitializer.toScala.flatMap(expressionReturnTypeFullName) + val javaParserVarType = variable.getTypeAsString + val variableTypeFullName = + tryWithSafeStackOverflow(typeInfoCalc.fullName(variable.getType)).toOption.flatten + // TODO: Surely the variable being declared can't already be in scope? + .orElse(scope.lookupVariable(name).typeFullName) + .orElse(scope.lookupType(javaParserVarType)) + + val typeFullName = + variableTypeFullName.orElse(initializerTypeFullName) + + // Need the actual resolvedType here for when the RHS is a lambda expression. + val resolvedExpectedType = + tryWithSafeStackOverflow(symbolSolver.toResolvedType( + variable.getType, + classOf[ResolvedType] + )).toOption + val initializerAsts = + astsForExpression(initializer, ExpectedType(typeFullName, resolvedExpectedType)) + + val typeName = typeFullName + .map(TypeNodePass.fullToShortName) + .getOrElse(guessTypeFullName(variable.getTypeAsString)) + val code = s"$typeName $name = ${initializerAsts.rootCodeOrEmpty}" + + val callNode = newOperatorCallNode( + Operators.assignment, + code, + typeFullName, + lineNumber, + columnNumber + ) + + val targetAst = scope.lookupVariable(name).getVariable() match + // TODO: This definitely feels like a bug. Why is the found member not being used for anything? + case Some(ScopeMember(_, false)) => + val thisType = scope.enclosingTypeDeclFullName + fieldAccessAst( + NameConstants.This, + thisType, + name, typeFullName, - lineNumber, - columnNumber + line(variable), + column(variable) ) - val targetAst = scope.lookupVariable(name).getVariable() match - // TODO: This definitely feels like a bug. Why is the found member not being used for anything? - case Some(ScopeMember(_, false)) => - val thisType = scope.enclosingTypeDeclFullName - fieldAccessAst( - NameConstants.This, - thisType, - name, - typeFullName, - line(variable), - column(variable) - ) - - case maybeCorrespNode => - val identifier = identifierNode( - variable, - name, - name, - typeFullName.getOrElse(TypeConstants.Any) - ) - Ast(identifier).withRefEdges(identifier, maybeCorrespNode.map(_.node).toList) - - // Since all partial constructors will be dealt with here, don't pass them up. - val declAst = callAst(callNode, Seq(targetAst) ++ initializerAsts) - - val constructorAsts = partialConstructorQueue.map(completeInitForConstructor( - _, - copyAstForVarDeclInit(targetAst) - )) - partialConstructorQueue.clear() - - Seq(declAst) ++ constructorAsts - } - - assignments.toList - end assignmentsForVarDecl - - private def completeInitForConstructor( - partialConstructor: PartialConstructor, - targetAst: Ast - ): Ast = - val initNode = partialConstructor.initNode - val args = partialConstructor.initArgs - - targetAst.root match - case Some(identifier: NewIdentifier) => - scope.lookupVariable(identifier.name).variableNode.foreach { variableNode => - diffGraph.addEdge(identifier, variableNode, EdgeTypes.REF) - } - - case _ => // Nothing to do in this case - callAst(initNode, args.toList, Some(targetAst)) - - private def astsForVariableDecl(varDecl: VariableDeclarationExpr): Seq[Ast] = - val locals = localsForVarDecl(varDecl) - val localAsts = locals.map { Ast(_) } - - locals.foreach { local => - scope.addLocal(local) - } - - val assignments = - assignmentsForVarDecl(varDecl.getVariables.asScala, line(varDecl), column(varDecl)) - - localAsts ++ assignments + case maybeCorrespNode => + val identifier = identifierNode( + variable, + name, + name, + typeFullName.getOrElse(TypeConstants.Any) + ) + Ast(identifier).withRefEdges(identifier, maybeCorrespNode.map(_.node).toList) + + // Since all partial constructors will be dealt with here, don't pass them up. + val declAst = callAst(callNode, Seq(targetAst) ++ initializerAsts) + + val constructorAsts = partialConstructorQueue.map(completeInitForConstructor( + _, + copyAstForVarDeclInit(targetAst) + )) + partialConstructorQueue.clear() + + Seq(declAst) ++ constructorAsts + } + + assignments.toList + end assignmentsForVarDecl + + private def completeInitForConstructor( + partialConstructor: PartialConstructor, + targetAst: Ast + ): Ast = + val initNode = partialConstructor.initNode + val args = partialConstructor.initArgs + + targetAst.root match + case Some(identifier: NewIdentifier) => + scope.lookupVariable(identifier.name).variableNode.foreach { variableNode => + diffGraph.addEdge(identifier, variableNode, EdgeTypes.REF) + } + + case _ => // Nothing to do in this case + callAst(initNode, args.toList, Some(targetAst)) + + private def astsForVariableDecl(varDecl: VariableDeclarationExpr): Seq[Ast] = + val locals = localsForVarDecl(varDecl) + val localAsts = locals.map { Ast(_) } + + locals.foreach { local => + scope.addLocal(local) + } + + val assignments = + assignmentsForVarDecl(varDecl.getVariables.asScala, line(varDecl), column(varDecl)) + + localAsts ++ assignments + + private def astForClassExpr(expr: ClassExpr): Ast = + val someTypeFullName = Some(TypeConstants.Class) + val callNode = newOperatorCallNode( + Operators.fieldAccess, + expr.toString, + someTypeFullName, + line(expr), + column(expr) + ) + + val identifierType = typeInfoCalc.fullName(expr.getType) + val identifier = identifierNode( + expr, + expr.getTypeAsString, + expr.getTypeAsString, + identifierType.getOrElse("ANY") + ) + val idAst = Ast(identifier) + + val fieldIdentifier = NewFieldIdentifier() + .canonicalName("class") + .code("class") + .lineNumber(line(expr)) + .columnNumber(column(expr)) + val fieldIdAst = Ast(fieldIdentifier) + + callAst(callNode, Seq(idAst, fieldIdAst)) + end astForClassExpr + + private def astForConditionalExpr(expr: ConditionalExpr, expectedType: ExpectedType): Ast = + val condAst = astsForExpression(expr.getCondition, ExpectedType.Boolean) + val thenAst = astsForExpression(expr.getThenExpr, expectedType) + val elseAst = astsForExpression(expr.getElseExpr, expectedType) + + val typeFullName = + expressionReturnTypeFullName(expr) + .orElse(thenAst.headOption.flatMap(_.rootType)) + .orElse(elseAst.headOption.flatMap(_.rootType)) + .orElse(expectedType.fullName) + .getOrElse(TypeConstants.Any) - private def astForClassExpr(expr: ClassExpr): Ast = - val someTypeFullName = Some(TypeConstants.Class) - val callNode = newOperatorCallNode( - Operators.fieldAccess, + val callNode = + newOperatorCallNode( + Operators.conditional, expr.toString, - someTypeFullName, + Some(typeFullName), line(expr), column(expr) ) - val identifierType = typeInfoCalc.fullName(expr.getType) - val identifier = identifierNode( - expr, - expr.getTypeAsString, - expr.getTypeAsString, - identifierType.getOrElse("ANY") - ) - val idAst = Ast(identifier) - - val fieldIdentifier = NewFieldIdentifier() - .canonicalName("class") - .code("class") - .lineNumber(line(expr)) - .columnNumber(column(expr)) - val fieldIdAst = Ast(fieldIdentifier) - - callAst(callNode, Seq(idAst, fieldIdAst)) - end astForClassExpr - - private def astForConditionalExpr(expr: ConditionalExpr, expectedType: ExpectedType): Ast = - val condAst = astsForExpression(expr.getCondition, ExpectedType.Boolean) - val thenAst = astsForExpression(expr.getThenExpr, expectedType) - val elseAst = astsForExpression(expr.getElseExpr, expectedType) - - val typeFullName = - expressionReturnTypeFullName(expr) - .orElse(thenAst.headOption.flatMap(_.rootType)) - .orElse(elseAst.headOption.flatMap(_.rootType)) - .orElse(expectedType.fullName) - .getOrElse(TypeConstants.Any) - - val callNode = - newOperatorCallNode( - Operators.conditional, - expr.toString, - Some(typeFullName), - line(expr), - column(expr) - ) + callAst(callNode, condAst ++ thenAst ++ elseAst) + end astForConditionalExpr - callAst(callNode, condAst ++ thenAst ++ elseAst) - end astForConditionalExpr + private def astForEnclosedExpression(expr: EnclosedExpr, expectedType: ExpectedType): Seq[Ast] = + astsForExpression(expr.getInner, expectedType) - private def astForEnclosedExpression(expr: EnclosedExpr, expectedType: ExpectedType): Seq[Ast] = - astsForExpression(expr.getInner, expectedType) - - private def astForFieldAccessExpr(expr: FieldAccessExpr, expectedType: ExpectedType): Ast = - val typeFullName = - expressionReturnTypeFullName(expr) - .orElse(expectedType.fullName) - .getOrElse(TypeConstants.Any) - - val callNode = - newOperatorCallNode( - Operators.fieldAccess, - expr.toString, - Some(typeFullName), - line(expr), - column(expr) - ) - - val fieldIdentifier = expr.getName - val identifierAsts = astsForExpression(expr.getScope, ExpectedType.empty) - val fieldIdentifierNode = NewFieldIdentifier() - .canonicalName(fieldIdentifier.toString) - .lineNumber(line(fieldIdentifier)) - .columnNumber(column(fieldIdentifier)) - .code(fieldIdentifier.toString) - val fieldIdAst = Ast(fieldIdentifierNode) - - callAst(callNode, identifierAsts ++ Seq(fieldIdAst)) - end astForFieldAccessExpr - - private def astForInstanceOfExpr(expr: InstanceOfExpr): Ast = - val booleanTypeFullName = Some(TypeConstants.Boolean) - val callNode = - newOperatorCallNode( - Operators.instanceOf, - expr.toString, - booleanTypeFullName, - line(expr), - column(expr) - ) - - val exprAst = astsForExpression(expr.getExpression, ExpectedType.empty) - val typeFullName = typeInfoCalc.fullName(expr.getType).getOrElse(TypeConstants.Any) - val typeNode = - NewTypeRef() - .code(expr.getType.toString) - .lineNumber(line(expr)) - .columnNumber(column(expr.getType)) - .typeFullName(typeFullName) - val typeAst = Ast(typeNode) - - callAst(callNode, exprAst ++ Seq(typeAst)) - end astForInstanceOfExpr - - private def fieldAccessAst( - identifierName: String, - identifierType: Option[String], - fieldIdentifierName: String, - returnType: Option[String], - lineNo: Option[Integer], - columnNo: Option[Integer] - ): Ast = - val typeFullName = - identifierType.orElse(Some(TypeConstants.Any)).map(typeInfoCalc.registerType) - val identifier = newIdentifierNode(identifierName, typeFullName.getOrElse("ANY")) - val maybeCorrespNode = scope.lookupVariable(identifierName).variableNode - - val fieldIdentifier = NewFieldIdentifier() - .code(fieldIdentifierName) - .canonicalName(fieldIdentifierName) - .lineNumber(lineNo) - .columnNumber(columnNo) - - val fieldAccessCode = s"$identifierName.$fieldIdentifierName" - val fieldAccess = - newOperatorCallNode( - Operators.fieldAccess, - fieldAccessCode, - returnType.orElse(Some(TypeConstants.Any)), - lineNo, - columnNo - ) - - val identifierAst = Ast(identifier) - val fieldIdentAst = Ast(fieldIdentifier) - - callAst(fieldAccess, Seq(identifierAst, fieldIdentAst)) - .withRefEdges(identifier, maybeCorrespNode.toList) - end fieldAccessAst - - private def astForNameExpr(nameExpr: NameExpr, expectedType: ExpectedType): Ast = - val name = nameExpr.getName.toString - val typeFullName = expressionReturnTypeFullName(nameExpr) + private def astForFieldAccessExpr(expr: FieldAccessExpr, expectedType: ExpectedType): Ast = + val typeFullName = + expressionReturnTypeFullName(expr) .orElse(expectedType.fullName) - .map(typeInfoCalc.registerType) - - tryWithSafeStackOverflow(nameExpr.resolve()) match - case Success(value) if value.isField => - val identifierName = if value.asField.isStatic then - // A static field represented by a NameExpr must belong to the class in which it's used. Static fields - // from other classes are represented by a FieldAccessExpr instead. - scope.enclosingTypeDecl.map(_.name).getOrElse( - guessTypeFullName(name) - ) - else - NameConstants.This - - val identifierTypeFullName = - value match - case fieldDecl: ResolvedFieldDeclaration => - // TODO It is not quite correct to use the declaring classes type. - // Instead we should take the using classes type which is either the same or a - // sub class of the declaring class. - typeInfoCalc.fullName(fieldDecl.declaringType()) - - fieldAccessAst( - identifierName, - identifierTypeFullName, - name, - typeFullName, - line(nameExpr), - column(nameExpr) - ) - - case _ => - val identifier = - identifierNode(nameExpr, name, name, typeFullName.getOrElse(TypeConstants.Any)) - - val variableOption = scope - .lookupVariable(name) - .variableNode - .collect { - case parameter: NewMethodParameterIn => parameter - - case local: NewLocal => local - } - - variableOption.foldLeft(Ast(identifier))((ast, variableNode) => - ast.withRefEdge(identifier, variableNode) - ) - end match - end astForNameExpr - - private def argumentTypesForMethodLike( - maybeResolvedMethodLike: Try[ResolvedMethodLikeDeclaration] - ): Option[List[String]] = - maybeResolvedMethodLike.toOption - .flatMap(calcParameterTypes(_, ResolvedTypeParametersMap.empty())) - - private def initNode( - namespaceName: Option[String], - argumentTypes: Option[List[String]], - argsSize: Int, - code: String, - lineNumber: Option[Integer] = None, - columnNumber: Option[Integer] = None - ): NewCall = - val initSignature = argumentTypes match - case Some(tpe) => composeMethodLikeSignature(TypeConstants.Void, tpe) - case _ if argsSize == 0 => composeMethodLikeSignature(TypeConstants.Void, Nil) - case _ => composeUnresolvedSignature(argsSize) - val namespace = namespaceName.getOrElse(Defines.UnresolvedNamespace) - val initMethodFullName = - composeMethodFullName(namespace, Defines.ConstructorMethodName, initSignature) - NewCall() - .name(Defines.ConstructorMethodName) - .methodFullName(initMethodFullName) - .signature(initSignature) - .typeFullName(TypeConstants.Void) - .code(code) - .dispatchType(DispatchTypes.STATIC_DISPATCH) - .lineNumber(lineNumber) - .columnNumber(columnNumber) - end initNode - - /** The below representation for constructor invocations and object creations was chosen for the - * sake of consistency with the Java frontend. It follows the bytecode approach of splitting a - * constructor call into separate `alloc` and `init` calls. - * - * There are two cases to consider. The first is a constructor invocation in an assignment, for - * example: - * - * Foo f = new Foo(42); - * - * is represented as - * - * Foo f = .alloc() f.init(42); - * - * The second case is a constructor invocation not in an assignment, for example as an argument - * to a method call. In this case, the representation does not stay as close to Java as in case - * 1. In particular, a new BLOCK is introduced to contain the constructor invocation. For - * example: - * - * foo(new Foo(42)); - * - * is represented as - * - * foo({ Foo temp = alloc(); temp.init(42); temp }) - * - * This is not valid Java code, but this representation is a decent compromise between staying - * faithful to Java and being consistent with the Java bytecode frontend. - */ - private def astForObjectCreationExpr( - expr: ObjectCreationExpr, - expectedType: ExpectedType - ): Ast = - val maybeResolvedExpr = tryWithSafeStackOverflow(expr.resolve()) - val argumentAsts = argAstsForCall(expr, maybeResolvedExpr, expr.getArguments) - - val typeFullName = - tryWithSafeStackOverflow(typeInfoCalc.fullName(expr.getType)).toOption.flatten - .orElse(scope.lookupType(expr.getTypeAsString)) - .orElse(expectedType.fullName) + .getOrElse(TypeConstants.Any) - val argumentTypes = argumentTypesForMethodLike(maybeResolvedExpr) - - val allocNode = newOperatorCallNode( - Operators.alloc, + val callNode = + newOperatorCallNode( + Operators.fieldAccess, expr.toString, - typeFullName.orElse(Some(TypeConstants.Any)), + Some(typeFullName), line(expr), column(expr) ) - val initCall = initNode( - typeFullName.orElse(Some(TypeConstants.Any)), - argumentTypes, - argumentAsts.size, + val fieldIdentifier = expr.getName + val identifierAsts = astsForExpression(expr.getScope, ExpectedType.empty) + val fieldIdentifierNode = NewFieldIdentifier() + .canonicalName(fieldIdentifier.toString) + .lineNumber(line(fieldIdentifier)) + .columnNumber(column(fieldIdentifier)) + .code(fieldIdentifier.toString) + val fieldIdAst = Ast(fieldIdentifierNode) + + callAst(callNode, identifierAsts ++ Seq(fieldIdAst)) + end astForFieldAccessExpr + + private def astForInstanceOfExpr(expr: InstanceOfExpr): Ast = + val booleanTypeFullName = Some(TypeConstants.Boolean) + val callNode = + newOperatorCallNode( + Operators.instanceOf, expr.toString, - line(expr) - ) - - // Assume that a block ast is required, since there isn't enough information to decide otherwise. - // This simplifies logic elsewhere, and unnecessary blocks will be garbage collected soon. - val blockAst = blockAstForConstructorInvocation( + booleanTypeFullName, line(expr), - column(expr), - allocNode, - initCall, - argumentAsts + column(expr) ) - expr.getParentNode.toScala match - case Some(parent) - if parent.isInstanceOf[VariableDeclarator] || parent.isInstanceOf[AssignExpr] => - val partialConstructor = PartialConstructor(initCall, argumentAsts, blockAst) - partialConstructorQueue.append(partialConstructor) - Ast(allocNode) - - case _ => - blockAst - end astForObjectCreationExpr - - private var tempConstCount = 0 - private def blockAstForConstructorInvocation( - lineNumber: Option[Integer], - columnNumber: Option[Integer], - allocNode: NewCall, - initNode: NewCall, - args: Seq[Ast] - ): Ast = - val blockNode = NewBlock() - .lineNumber(lineNumber) - .columnNumber(columnNumber) - .typeFullName(allocNode.typeFullName) - - val tempName = "$obj" ++ tempConstCount.toString - tempConstCount += 1 - val identifier = newIdentifierNode(tempName, allocNode.typeFullName) - val identifierAst = Ast(identifier) - - val allocAst = Ast(allocNode) - - val assignmentNode = newOperatorCallNode( - Operators.assignment, - PropertyDefaults.Code, - Some(allocNode.typeFullName) + val exprAst = astsForExpression(expr.getExpression, ExpectedType.empty) + val typeFullName = typeInfoCalc.fullName(expr.getType).getOrElse(TypeConstants.Any) + val typeNode = + NewTypeRef() + .code(expr.getType.toString) + .lineNumber(line(expr)) + .columnNumber(column(expr.getType)) + .typeFullName(typeFullName) + val typeAst = Ast(typeNode) + + callAst(callNode, exprAst ++ Seq(typeAst)) + end astForInstanceOfExpr + + private def fieldAccessAst( + identifierName: String, + identifierType: Option[String], + fieldIdentifierName: String, + returnType: Option[String], + lineNo: Option[Integer], + columnNo: Option[Integer] + ): Ast = + val typeFullName = + identifierType.orElse(Some(TypeConstants.Any)).map(typeInfoCalc.registerType) + val identifier = newIdentifierNode(identifierName, typeFullName.getOrElse("ANY")) + val maybeCorrespNode = scope.lookupVariable(identifierName).variableNode + + val fieldIdentifier = NewFieldIdentifier() + .code(fieldIdentifierName) + .canonicalName(fieldIdentifierName) + .lineNumber(lineNo) + .columnNumber(columnNo) + + val fieldAccessCode = s"$identifierName.$fieldIdentifierName" + val fieldAccess = + newOperatorCallNode( + Operators.fieldAccess, + fieldAccessCode, + returnType.orElse(Some(TypeConstants.Any)), + lineNo, + columnNo ) - val assignmentAst = callAst(assignmentNode, List(identifierAst, allocAst)) - - val identifierWithDefaultOrder = identifier.copy.order(PropertyDefaults.Order) - val identifierForInit = identifierWithDefaultOrder.copy - val initWithDefaultOrder = initNode.order(PropertyDefaults.Order) - val initAst = callAst(initWithDefaultOrder, args, Some(Ast(identifierForInit))) - - val returnAst = Ast(identifierWithDefaultOrder.copy) - - Ast(blockNode) - .withChild(assignmentAst) - .withChild(initAst) - .withChild(returnAst) - end blockAstForConstructorInvocation - - private def astForThisExpr(expr: ThisExpr, expectedType: ExpectedType): Ast = - val typeFullName = - expressionReturnTypeFullName(expr) - .orElse(expectedType.fullName) - - val identifier = - identifierNode(expr, expr.toString, expr.toString, typeFullName.getOrElse("ANY")) - val thisParam = scope.lookupVariable(NameConstants.This).variableNode - - thisParam.foreach { thisNode => - diffGraph.addEdge(identifier, thisNode, EdgeTypes.REF) - } - - Ast(identifier) - - private def astForExplicitConstructorInvocation(stmt: ExplicitConstructorInvocationStmt): Ast = - val maybeResolved = tryWithSafeStackOverflow(stmt.resolve()) - val args = argAstsForCall(stmt, maybeResolved, stmt.getArguments) - val argTypes = argumentTypesForMethodLike(maybeResolved) + val identifierAst = Ast(identifier) + val fieldIdentAst = Ast(fieldIdentifier) + + callAst(fieldAccess, Seq(identifierAst, fieldIdentAst)) + .withRefEdges(identifier, maybeCorrespNode.toList) + end fieldAccessAst + + private def astForNameExpr(nameExpr: NameExpr, expectedType: ExpectedType): Ast = + val name = nameExpr.getName.toString + val typeFullName = expressionReturnTypeFullName(nameExpr) + .orElse(expectedType.fullName) + .map(typeInfoCalc.registerType) + + tryWithSafeStackOverflow(nameExpr.resolve()) match + case Success(value) if value.isField => + val identifierName = if value.asField.isStatic then + // A static field represented by a NameExpr must belong to the class in which it's used. Static fields + // from other classes are represented by a FieldAccessExpr instead. + scope.enclosingTypeDecl.map(_.name).getOrElse( + guessTypeFullName(name) + ) + else + NameConstants.This + + val identifierTypeFullName = + value match + case fieldDecl: ResolvedFieldDeclaration => + // TODO It is not quite correct to use the declaring classes type. + // Instead we should take the using classes type which is either the same or a + // sub class of the declaring class. + typeInfoCalc.fullName(fieldDecl.declaringType()) + + fieldAccessAst( + identifierName, + identifierTypeFullName, + name, + typeFullName, + line(nameExpr), + column(nameExpr) + ) - val typeFullName = maybeResolved.toOption - .map(_.declaringType()) - .flatMap(typeInfoCalc.fullName) + case _ => + val identifier = + identifierNode(nameExpr, name, name, typeFullName.getOrElse(TypeConstants.Any)) - val callRoot = initNode( - typeFullName.orElse(Some(TypeConstants.Any)), - argTypes, - args.size, - stmt.toString, - line(stmt), - column(stmt) - ) + val variableOption = scope + .lookupVariable(name) + .variableNode + .collect { + case parameter: NewMethodParameterIn => parameter - val thisNode = - newIdentifierNode(NameConstants.This, typeFullName.getOrElse(TypeConstants.Any)) - scope.lookupVariable(NameConstants.This).variableNode.foreach { thisParam => - diffGraph.addEdge(thisNode, thisParam, EdgeTypes.REF) - } - val thisAst = Ast(thisNode) - - callAst(callRoot, args, Some(thisAst)) - end astForExplicitConstructorInvocation - - private def astsForExpression(expression: Expression, expectedType: ExpectedType): Seq[Ast] = - // TODO: Implement missing handlers - // case _: MethodReferenceExpr => Seq() - // case _: PatternExpr => Seq() - // case _: SuperExpr => Seq() - // case _: SwitchExpr => Seq() - // case _: TypeExpr => Seq() - expression match - case _: AnnotationExpr => Seq() - case x: ArrayAccessExpr => Seq(astForArrayAccessExpr(x, expectedType)) - case x: ArrayCreationExpr => Seq(astForArrayCreationExpr(x, expectedType)) - case x: ArrayInitializerExpr => Seq(astForArrayInitializerExpr(x, expectedType)) - case x: AssignExpr => astsForAssignExpr(x, expectedType) - case x: BinaryExpr => Seq(astForBinaryExpr(x, expectedType)) - case x: CastExpr => Seq(astForCastExpr(x, expectedType)) - case x: ClassExpr => Seq(astForClassExpr(x)) - case x: ConditionalExpr => Seq(astForConditionalExpr(x, expectedType)) - case x: EnclosedExpr => astForEnclosedExpression(x, expectedType) - case x: FieldAccessExpr => Seq(astForFieldAccessExpr(x, expectedType)) - case x: InstanceOfExpr => Seq(astForInstanceOfExpr(x)) - case x: LambdaExpr => Seq(astForLambdaExpr(x, expectedType)) - case x: LiteralExpr => Seq(astForLiteralExpr(x)) - case x: MethodCallExpr => Seq(astForMethodCall(x, expectedType)) - case x: NameExpr => Seq(astForNameExpr(x, expectedType)) - case x: ObjectCreationExpr => Seq(astForObjectCreationExpr(x, expectedType)) - case x: SuperExpr => Seq(astForSuperExpr(x, expectedType)) - case x: ThisExpr => Seq(astForThisExpr(x, expectedType)) - case x: UnaryExpr => Seq(astForUnaryExpr(x, expectedType)) - case x: VariableDeclarationExpr => astsForVariableDecl(x) - case x => Seq(unknownAst(x)) - - private def unknownAst(node: Node): Ast = Ast(unknownNode(node, node.toString)) - - private def someWithDotSuffix(prefix: String): Option[String] = Some(s"$prefix.") - - private def codeForScopeExpr( - scopeExpr: Expression, - isScopeForStaticCall: Boolean - ): Option[String] = - scopeExpr match - case scope: NameExpr => someWithDotSuffix(scope.getNameAsString) - - case fieldAccess: FieldAccessExpr => - val maybeScopeString = - codeForScopeExpr(fieldAccess.getScope, isScopeForStaticCall = false) - val name = fieldAccess.getNameAsString - maybeScopeString - .map { scopeString => - s"$scopeString$name" - } - .orElse(Some(name)) - .flatMap(someWithDotSuffix) - - case _: SuperExpr => someWithDotSuffix(NameConstants.Super) - - case _: ThisExpr => someWithDotSuffix(NameConstants.This) - - case scopeMethodCall: MethodCallExpr => - codePrefixForMethodCall(scopeMethodCall) match - case "" => Some("") - case prefix => - val argumentsCode = getArgumentCodeString(scopeMethodCall.getArguments) - someWithDotSuffix( - s"$prefix${scopeMethodCall.getNameAsString}($argumentsCode)" - ) - - case objectCreationExpr: ObjectCreationExpr => - val typeName = objectCreationExpr.getTypeAsString - val argumentsString = getArgumentCodeString(objectCreationExpr.getArguments) - someWithDotSuffix(s"new $typeName($argumentsString)") + case local: NewLocal => local + } - case _ => None + variableOption.foldLeft(Ast(identifier))((ast, variableNode) => + ast.withRefEdge(identifier, variableNode) + ) + end match + end astForNameExpr + + private def argumentTypesForMethodLike( + maybeResolvedMethodLike: Try[ResolvedMethodLikeDeclaration] + ): Option[List[String]] = + maybeResolvedMethodLike.toOption + .flatMap(calcParameterTypes(_, ResolvedTypeParametersMap.empty())) + + private def initNode( + namespaceName: Option[String], + argumentTypes: Option[List[String]], + argsSize: Int, + code: String, + lineNumber: Option[Integer] = None, + columnNumber: Option[Integer] = None + ): NewCall = + val initSignature = argumentTypes match + case Some(tpe) => composeMethodLikeSignature(TypeConstants.Void, tpe) + case _ if argsSize == 0 => composeMethodLikeSignature(TypeConstants.Void, Nil) + case _ => composeUnresolvedSignature(argsSize) + val namespace = namespaceName.getOrElse(Defines.UnresolvedNamespace) + val initMethodFullName = + composeMethodFullName(namespace, Defines.ConstructorMethodName, initSignature) + NewCall() + .name(Defines.ConstructorMethodName) + .methodFullName(initMethodFullName) + .signature(initSignature) + .typeFullName(TypeConstants.Void) + .code(code) + .dispatchType(DispatchTypes.STATIC_DISPATCH) + .lineNumber(lineNumber) + .columnNumber(columnNumber) + end initNode + + /** The below representation for constructor invocations and object creations was chosen for the + * sake of consistency with the Java frontend. It follows the bytecode approach of splitting a + * constructor call into separate `alloc` and `init` calls. + * + * There are two cases to consider. The first is a constructor invocation in an assignment, for + * example: + * + * Foo f = new Foo(42); + * + * is represented as + * + * Foo f = .alloc() f.init(42); + * + * The second case is a constructor invocation not in an assignment, for example as an argument + * to a method call. In this case, the representation does not stay as close to Java as in case + * 1. In particular, a new BLOCK is introduced to contain the constructor invocation. For + * example: + * + * foo(new Foo(42)); + * + * is represented as + * + * foo({ Foo temp = alloc(); temp.init(42); temp }) + * + * This is not valid Java code, but this representation is a decent compromise between staying + * faithful to Java and being consistent with the Java bytecode frontend. + */ + private def astForObjectCreationExpr( + expr: ObjectCreationExpr, + expectedType: ExpectedType + ): Ast = + val maybeResolvedExpr = tryWithSafeStackOverflow(expr.resolve()) + val argumentAsts = argAstsForCall(expr, maybeResolvedExpr, expr.getArguments) + + val typeFullName = + tryWithSafeStackOverflow(typeInfoCalc.fullName(expr.getType)).toOption.flatten + .orElse(scope.lookupType(expr.getTypeAsString)) + .orElse(expectedType.fullName) - private def codePrefixForMethodCall(call: MethodCallExpr): String = - tryWithSafeStackOverflow(call.resolve()) match - case Success(resolvedCall) => - call.getScope.toScala - .flatMap(codeForScopeExpr(_, resolvedCall.isStatic)) - .getOrElse(if resolvedCall.isStatic then "" else s"${NameConstants.This}.") - - case _ => - // If the call is unresolvable, we cannot make a good guess about what the prefix should be - "" - - private def createObjectNode( - typeFullName: Option[String], - call: MethodCallExpr, - dispatchType: String - ): Option[NewIdentifier] = - val maybeScope = call.getScope.toScala - - Option.when(maybeScope.isDefined || dispatchType == DispatchTypes.DYNAMIC_DISPATCH) { - val name = maybeScope.map(_.toString).getOrElse(NameConstants.This) - identifierNode(call, name, name, typeFullName.getOrElse("ANY")) - } + val argumentTypes = argumentTypesForMethodLike(maybeResolvedExpr) + + val allocNode = newOperatorCallNode( + Operators.alloc, + expr.toString, + typeFullName.orElse(Some(TypeConstants.Any)), + line(expr), + column(expr) + ) + + val initCall = initNode( + typeFullName.orElse(Some(TypeConstants.Any)), + argumentTypes, + argumentAsts.size, + expr.toString, + line(expr) + ) + + // Assume that a block ast is required, since there isn't enough information to decide otherwise. + // This simplifies logic elsewhere, and unnecessary blocks will be garbage collected soon. + val blockAst = blockAstForConstructorInvocation( + line(expr), + column(expr), + allocNode, + initCall, + argumentAsts + ) + + expr.getParentNode.toScala match + case Some(parent) + if parent.isInstanceOf[VariableDeclarator] || parent.isInstanceOf[AssignExpr] => + val partialConstructor = PartialConstructor(initCall, argumentAsts, blockAst) + partialConstructorQueue.append(partialConstructor) + Ast(allocNode) + + case _ => + blockAst + end astForObjectCreationExpr + + private var tempConstCount = 0 + private def blockAstForConstructorInvocation( + lineNumber: Option[Integer], + columnNumber: Option[Integer], + allocNode: NewCall, + initNode: NewCall, + args: Seq[Ast] + ): Ast = + val blockNode = NewBlock() + .lineNumber(lineNumber) + .columnNumber(columnNumber) + .typeFullName(allocNode.typeFullName) + + val tempName = "$obj" ++ tempConstCount.toString + tempConstCount += 1 + val identifier = newIdentifierNode(tempName, allocNode.typeFullName) + val identifierAst = Ast(identifier) + + val allocAst = Ast(allocNode) + + val assignmentNode = newOperatorCallNode( + Operators.assignment, + PropertyDefaults.Code, + Some(allocNode.typeFullName) + ) + + val assignmentAst = callAst(assignmentNode, List(identifierAst, allocAst)) + + val identifierWithDefaultOrder = identifier.copy.order(PropertyDefaults.Order) + val identifierForInit = identifierWithDefaultOrder.copy + val initWithDefaultOrder = initNode.order(PropertyDefaults.Order) + val initAst = callAst(initWithDefaultOrder, args, Some(Ast(identifierForInit))) + + val returnAst = Ast(identifierWithDefaultOrder.copy) + + Ast(blockNode) + .withChild(assignmentAst) + .withChild(initAst) + .withChild(returnAst) + end blockAstForConstructorInvocation + + private def astForThisExpr(expr: ThisExpr, expectedType: ExpectedType): Ast = + val typeFullName = + expressionReturnTypeFullName(expr) + .orElse(expectedType.fullName) - private def nextLambdaName(): String = - s"$LambdaNamePrefix${lambdaKeyPool.next}" - - private def nextIndexName(): String = - s"$IndexNamePrefix${indexKeyPool.next}" - - private def nextIterableName(): String = - s"$IterableNamePrefix${iterableKeyPool.next}" - - private def genericParamTypeMapForLambda(expectedType: ExpectedType) - : ResolvedTypeParametersMap = - expectedType.resolvedType - // This should always be true for correct code - .collect { case r: ResolvedReferenceType => r } - .map(_.typeParametersMap()) - .getOrElse(new ResolvedTypeParametersMap.Builder().build()) - - private def buildParamListForLambda( - expr: LambdaExpr, - maybeBoundMethod: Option[ResolvedMethodDeclaration], - expectedTypeParamTypes: ResolvedTypeParametersMap - ): Seq[Ast] = - val lambdaParameters = expr.getParameters.asScala.toList - val paramTypesList = maybeBoundMethod match - case Some(resolvedMethod) => - val resolvedParameters = - (0 until resolvedMethod.getNumberOfParams).map(resolvedMethod.getParam) - - // Substitute generic typeParam with the expected type if it can be found; leave unchanged otherwise. - resolvedParameters.map(param => Try(param.getType)).map { - case Success(resolvedType: ResolvedTypeVariable) => - val typ = expectedTypeParamTypes.getValue(resolvedType.asTypeParameter) - typeInfoCalc.fullName(typ) - - case Success(resolvedType) => typeInfoCalc.fullName(resolvedType) - - case Failure(_) => None + val identifier = + identifierNode(expr, expr.toString, expr.toString, typeFullName.getOrElse("ANY")) + val thisParam = scope.lookupVariable(NameConstants.This).variableNode + + thisParam.foreach { thisNode => + diffGraph.addEdge(identifier, thisNode, EdgeTypes.REF) + } + + Ast(identifier) + + private def astForExplicitConstructorInvocation(stmt: ExplicitConstructorInvocationStmt): Ast = + val maybeResolved = tryWithSafeStackOverflow(stmt.resolve()) + val args = argAstsForCall(stmt, maybeResolved, stmt.getArguments) + val argTypes = argumentTypesForMethodLike(maybeResolved) + + val typeFullName = maybeResolved.toOption + .map(_.declaringType()) + .flatMap(typeInfoCalc.fullName) + + val callRoot = initNode( + typeFullName.orElse(Some(TypeConstants.Any)), + argTypes, + args.size, + stmt.toString, + line(stmt), + column(stmt) + ) + + val thisNode = + newIdentifierNode(NameConstants.This, typeFullName.getOrElse(TypeConstants.Any)) + scope.lookupVariable(NameConstants.This).variableNode.foreach { thisParam => + diffGraph.addEdge(thisNode, thisParam, EdgeTypes.REF) + } + val thisAst = Ast(thisNode) + + callAst(callRoot, args, Some(thisAst)) + end astForExplicitConstructorInvocation + + private def astsForExpression(expression: Expression, expectedType: ExpectedType): Seq[Ast] = + // TODO: Implement missing handlers + // case _: MethodReferenceExpr => Seq() + // case _: PatternExpr => Seq() + // case _: SuperExpr => Seq() + // case _: SwitchExpr => Seq() + // case _: TypeExpr => Seq() + expression match + case _: AnnotationExpr => Seq() + case x: ArrayAccessExpr => Seq(astForArrayAccessExpr(x, expectedType)) + case x: ArrayCreationExpr => Seq(astForArrayCreationExpr(x, expectedType)) + case x: ArrayInitializerExpr => Seq(astForArrayInitializerExpr(x, expectedType)) + case x: AssignExpr => astsForAssignExpr(x, expectedType) + case x: BinaryExpr => Seq(astForBinaryExpr(x, expectedType)) + case x: CastExpr => Seq(astForCastExpr(x, expectedType)) + case x: ClassExpr => Seq(astForClassExpr(x)) + case x: ConditionalExpr => Seq(astForConditionalExpr(x, expectedType)) + case x: EnclosedExpr => astForEnclosedExpression(x, expectedType) + case x: FieldAccessExpr => Seq(astForFieldAccessExpr(x, expectedType)) + case x: InstanceOfExpr => Seq(astForInstanceOfExpr(x)) + case x: LambdaExpr => Seq(astForLambdaExpr(x, expectedType)) + case x: LiteralExpr => Seq(astForLiteralExpr(x)) + case x: MethodCallExpr => Seq(astForMethodCall(x, expectedType)) + case x: NameExpr => Seq(astForNameExpr(x, expectedType)) + case x: ObjectCreationExpr => Seq(astForObjectCreationExpr(x, expectedType)) + case x: SuperExpr => Seq(astForSuperExpr(x, expectedType)) + case x: ThisExpr => Seq(astForThisExpr(x, expectedType)) + case x: UnaryExpr => Seq(astForUnaryExpr(x, expectedType)) + case x: VariableDeclarationExpr => astsForVariableDecl(x) + case x => Seq(unknownAst(x)) + + private def unknownAst(node: Node): Ast = Ast(unknownNode(node, node.toString)) + + private def someWithDotSuffix(prefix: String): Option[String] = Some(s"$prefix.") + + private def codeForScopeExpr( + scopeExpr: Expression, + isScopeForStaticCall: Boolean + ): Option[String] = + scopeExpr match + case scope: NameExpr => someWithDotSuffix(scope.getNameAsString) + + case fieldAccess: FieldAccessExpr => + val maybeScopeString = + codeForScopeExpr(fieldAccess.getScope, isScopeForStaticCall = false) + val name = fieldAccess.getNameAsString + maybeScopeString + .map { scopeString => + s"$scopeString$name" } + .orElse(Some(name)) + .flatMap(someWithDotSuffix) - case None => - // Unless types are explicitly specified in the lambda definition, - // this will yield the erased types which is why the actual lambda - // expression parameters are only used as a fallback. - lambdaParameters - .map(_.getType) - .map(typeInfoCalc.fullName) + case _: SuperExpr => someWithDotSuffix(NameConstants.Super) - if paramTypesList.sizeIs != lambdaParameters.size then - logger.debug( - s"Found different number lambda params and param types for $expr. Some parameters will be missing." - ) + case _: ThisExpr => someWithDotSuffix(NameConstants.This) - val parameterNodes = lambdaParameters - .zip(paramTypesList) - .zipWithIndex - .map { case ((param, maybeType), idx) => - val name = param.getNameAsString - val typeFullName = maybeType.getOrElse(TypeConstants.Any) - val code = s"$typeFullName $name" - val evalStrat = - if param.getType.isPrimitiveType then EvaluationStrategies.BY_VALUE - else EvaluationStrategies.BY_SHARING - val paramNode = NewMethodParameterIn() - .name(name) - .index(idx + 1) - .order(idx + 1) - .code(code) - .evaluationStrategy(evalStrat) - .typeFullName(typeFullName) - .lineNumber(line(expr)) - .columnNumber(column(expr)) - typeInfoCalc.registerType(typeFullName) - paramNode - } + case scopeMethodCall: MethodCallExpr => + codePrefixForMethodCall(scopeMethodCall) match + case "" => Some("") + case prefix => + val argumentsCode = getArgumentCodeString(scopeMethodCall.getArguments) + someWithDotSuffix( + s"$prefix${scopeMethodCall.getNameAsString}($argumentsCode)" + ) - parameterNodes.foreach { paramNode => - scope.addParameter(paramNode) + case objectCreationExpr: ObjectCreationExpr => + val typeName = objectCreationExpr.getTypeAsString + val argumentsString = getArgumentCodeString(objectCreationExpr.getArguments) + someWithDotSuffix(s"new $typeName($argumentsString)") + + case _ => None + + private def codePrefixForMethodCall(call: MethodCallExpr): String = + tryWithSafeStackOverflow(call.resolve()) match + case Success(resolvedCall) => + call.getScope.toScala + .flatMap(codeForScopeExpr(_, resolvedCall.isStatic)) + .getOrElse(if resolvedCall.isStatic then "" else s"${NameConstants.This}.") + + case _ => + // If the call is unresolvable, we cannot make a good guess about what the prefix should be + "" + + private def createObjectNode( + typeFullName: Option[String], + call: MethodCallExpr, + dispatchType: String + ): Option[NewIdentifier] = + val maybeScope = call.getScope.toScala + + Option.when(maybeScope.isDefined || dispatchType == DispatchTypes.DYNAMIC_DISPATCH) { + val name = maybeScope.map(_.toString).getOrElse(NameConstants.This) + identifierNode(call, name, name, typeFullName.getOrElse("ANY")) + } + + private def nextLambdaName(): String = + s"$LambdaNamePrefix${lambdaKeyPool.next}" + + private def nextIndexName(): String = + s"$IndexNamePrefix${indexKeyPool.next}" + + private def nextIterableName(): String = + s"$IterableNamePrefix${iterableKeyPool.next}" + + private def genericParamTypeMapForLambda(expectedType: ExpectedType): ResolvedTypeParametersMap = + expectedType.resolvedType + // This should always be true for correct code + .collect { case r: ResolvedReferenceType => r } + .map(_.typeParametersMap()) + .getOrElse(new ResolvedTypeParametersMap.Builder().build()) + + private def buildParamListForLambda( + expr: LambdaExpr, + maybeBoundMethod: Option[ResolvedMethodDeclaration], + expectedTypeParamTypes: ResolvedTypeParametersMap + ): Seq[Ast] = + val lambdaParameters = expr.getParameters.asScala.toList + val paramTypesList = maybeBoundMethod match + case Some(resolvedMethod) => + val resolvedParameters = + (0 until resolvedMethod.getNumberOfParams).map(resolvedMethod.getParam) + + // Substitute generic typeParam with the expected type if it can be found; leave unchanged otherwise. + resolvedParameters.map(param => Try(param.getType)).map { + case Success(resolvedType: ResolvedTypeVariable) => + val typ = expectedTypeParamTypes.getValue(resolvedType.asTypeParameter) + typeInfoCalc.fullName(typ) + + case Success(resolvedType) => typeInfoCalc.fullName(resolvedType) + + case Failure(_) => None + } + + case None => + // Unless types are explicitly specified in the lambda definition, + // this will yield the erased types which is why the actual lambda + // expression parameters are only used as a fallback. + lambdaParameters + .map(_.getType) + .map(typeInfoCalc.fullName) + + if paramTypesList.sizeIs != lambdaParameters.size then + logger.debug( + s"Found different number lambda params and param types for $expr. Some parameters will be missing." + ) + + val parameterNodes = lambdaParameters + .zip(paramTypesList) + .zipWithIndex + .map { case ((param, maybeType), idx) => + val name = param.getNameAsString + val typeFullName = maybeType.getOrElse(TypeConstants.Any) + val code = s"$typeFullName $name" + val evalStrat = + if param.getType.isPrimitiveType then EvaluationStrategies.BY_VALUE + else EvaluationStrategies.BY_SHARING + val paramNode = NewMethodParameterIn() + .name(name) + .index(idx + 1) + .order(idx + 1) + .code(code) + .evaluationStrategy(evalStrat) + .typeFullName(typeFullName) + .lineNumber(line(expr)) + .columnNumber(column(expr)) + typeInfoCalc.registerType(typeFullName) + paramNode } - parameterNodes.map(Ast(_)) - end buildParamListForLambda - - private def getLambdaReturnType( - maybeResolvedLambdaType: Option[ResolvedType], - maybeBoundMethod: Option[ResolvedMethodDeclaration], - expectedTypeParamTypes: ResolvedTypeParametersMap - ): Option[String] = - val maybeBoundMethodReturnType = maybeBoundMethod.flatMap { boundMethod => - Try(boundMethod.getReturnType).collect { - case returnType: ResolvedTypeVariable => - expectedTypeParamTypes.getValue(returnType.asTypeParameter) - case other => other - }.toOption + parameterNodes.foreach { paramNode => + scope.addParameter(paramNode) + } + + parameterNodes.map(Ast(_)) + end buildParamListForLambda + + private def getLambdaReturnType( + maybeResolvedLambdaType: Option[ResolvedType], + maybeBoundMethod: Option[ResolvedMethodDeclaration], + expectedTypeParamTypes: ResolvedTypeParametersMap + ): Option[String] = + val maybeBoundMethodReturnType = maybeBoundMethod.flatMap { boundMethod => + Try(boundMethod.getReturnType).collect { + case returnType: ResolvedTypeVariable => + expectedTypeParamTypes.getValue(returnType.asTypeParameter) + case other => other + }.toOption + } + + val returnType = maybeBoundMethodReturnType.orElse(maybeResolvedLambdaType) + returnType.flatMap(typeInfoCalc.fullName) + + private def closureBindingsForCapturedNodes(lambdaMethodName: String): List[ClosureBindingEntry] = + scope.capturedVariables.map { capturedNode => + val closureBindingId = s"$filename:$lambdaMethodName:${capturedNode.name}" + val closureBindingNode = + newClosureBindingNode( + closureBindingId, + capturedNode.name, + EvaluationStrategies.BY_SHARING + ) + passes.ClosureBindingEntry(capturedNode, closureBindingNode) + } + + private def localsForCapturedNodes(closureBindingEntries: List[ClosureBindingEntry]) + : List[NewLocal] = + val localsForCaptured = + closureBindingEntries.map { case ClosureBindingEntry(node, binding) => + val local = NewLocal() + .name(node.name) + .code(node.name) + .closureBindingId(binding.closureBindingId) + .typeFullName(node.typeFullName) + local } - - val returnType = maybeBoundMethodReturnType.orElse(maybeResolvedLambdaType) - returnType.flatMap(typeInfoCalc.fullName) - - private def closureBindingsForCapturedNodes(lambdaMethodName: String) - : List[ClosureBindingEntry] = - scope.capturedVariables.map { capturedNode => - val closureBindingId = s"$filename:$lambdaMethodName:${capturedNode.name}" - val closureBindingNode = - newClosureBindingNode( - closureBindingId, - capturedNode.name, - EvaluationStrategies.BY_SHARING - ) - passes.ClosureBindingEntry(capturedNode, closureBindingNode) + localsForCaptured.foreach { local => scope.addLocal(local) } + localsForCaptured + + private def astForLambdaBody( + body: Statement, + localsForCapturedVars: Seq[NewLocal], + returnType: Option[String] + ): Ast = + body match + case block: BlockStmt => + astForBlockStatement(block, prefixAsts = localsForCapturedVars.map(Ast(_))) + + case stmt => + val blockAst = Ast(NewBlock().lineNumber(line(body))) + val bodyAst = if returnType.contains(TypeConstants.Void) then + astsForStatement(stmt) + else + val returnNode = + NewReturn() + .code(s"return ${body.toString}") + .lineNumber(line(body)) + val returnArgs = astsForStatement(stmt) + Seq(returnAst(returnNode, returnArgs)) + + blockAst + .withChildren(localsForCapturedVars.map(Ast(_))) + .withChildren(bodyAst) + + private def lambdaMethodSignature(returnType: Option[String], parameters: Seq[Ast]): String = + val maybeParameterTypes = toOptionList(parameters.map(_.rootType)) + val containsEmptyType = + maybeParameterTypes.exists(_.contains(ParameterDefaults.TypeFullName)) + + (returnType, maybeParameterTypes) match + case (Some(returnTpe), Some(parameterTpes)) if !containsEmptyType => + composeMethodLikeSignature(returnTpe, parameterTpes) + + case _ => composeUnresolvedSignature(parameters.size) + + private def createLambdaMethodNode( + lambdaName: String, + parameters: Seq[Ast], + returnType: Option[String] + ): NewMethod = + val enclosingTypeName = + scope.enclosingTypeDeclFullName.getOrElse(Defines.UnresolvedNamespace) + val signature = lambdaMethodSignature(returnType, parameters) + val lambdaFullName = composeMethodFullName(enclosingTypeName, lambdaName, signature) + + NewMethod() + .name(lambdaName) + .fullName(lambdaFullName) + .signature(signature) + .filename(filename) + .code("") + + private def addClosureBindingsToDiffGraph( + bindingEntries: Iterable[ClosureBindingEntry], + methodRef: NewMethodRef + ): Unit = + bindingEntries.foreach { case ClosureBindingEntry(nodeTypeInfo, closureBinding) => + diffGraph.addNode(closureBinding) + diffGraph.addEdge(closureBinding, nodeTypeInfo.node, EdgeTypes.REF) + diffGraph.addEdge(methodRef, closureBinding, EdgeTypes.CAPTURE) + } + + private def createAndPushLambdaMethod( + expr: LambdaExpr, + lambdaMethodName: String, + implementedInfo: LambdaImplementedInfo, + localsForCaptured: Seq[NewLocal], + expectedLambdaType: ExpectedType + ): NewMethod = + val implementedMethod = implementedInfo.implementedMethod + val implementedInterface = implementedInfo.implementedInterface + + // We need to get this information from the expected type as the JavaParser + // symbol solver returns the erased types when resolving the lambda itself. + val expectedTypeParamTypes = genericParamTypeMapForLambda(expectedLambdaType) + val parametersWithoutThis = + buildParamListForLambda(expr, implementedMethod, expectedTypeParamTypes) + + val returnType = + getLambdaReturnType(implementedInterface, implementedMethod, expectedTypeParamTypes) + + val lambdaMethodBody = astForLambdaBody(expr.getBody, localsForCaptured, returnType) + + val thisParam = lambdaMethodBody.nodes + .collect { case identifier: NewIdentifier => identifier } + .find { identifier => + identifier.name == NameConstants.This || identifier.name == NameConstants.Super } - - private def localsForCapturedNodes(closureBindingEntries: List[ClosureBindingEntry]) - : List[NewLocal] = - val localsForCaptured = - closureBindingEntries.map { case ClosureBindingEntry(node, binding) => - val local = NewLocal() - .name(node.name) - .code(node.name) - .closureBindingId(binding.closureBindingId) - .typeFullName(node.typeFullName) - local - } - localsForCaptured.foreach { local => scope.addLocal(local) } - localsForCaptured - - private def astForLambdaBody( - body: Statement, - localsForCapturedVars: Seq[NewLocal], - returnType: Option[String] - ): Ast = - body match - case block: BlockStmt => - astForBlockStatement(block, prefixAsts = localsForCapturedVars.map(Ast(_))) - - case stmt => - val blockAst = Ast(NewBlock().lineNumber(line(body))) - val bodyAst = if returnType.contains(TypeConstants.Void) then - astsForStatement(stmt) - else - val returnNode = - NewReturn() - .code(s"return ${body.toString}") - .lineNumber(line(body)) - val returnArgs = astsForStatement(stmt) - Seq(returnAst(returnNode, returnArgs)) - - blockAst - .withChildren(localsForCapturedVars.map(Ast(_))) - .withChildren(bodyAst) - - private def lambdaMethodSignature(returnType: Option[String], parameters: Seq[Ast]): String = - val maybeParameterTypes = toOptionList(parameters.map(_.rootType)) - val containsEmptyType = - maybeParameterTypes.exists(_.contains(ParameterDefaults.TypeFullName)) - - (returnType, maybeParameterTypes) match - case (Some(returnTpe), Some(parameterTpes)) if !containsEmptyType => - composeMethodLikeSignature(returnTpe, parameterTpes) - - case _ => composeUnresolvedSignature(parameters.size) - - private def createLambdaMethodNode( - lambdaName: String, - parameters: Seq[Ast], - returnType: Option[String] - ): NewMethod = - val enclosingTypeName = - scope.enclosingTypeDeclFullName.getOrElse(Defines.UnresolvedNamespace) - val signature = lambdaMethodSignature(returnType, parameters) - val lambdaFullName = composeMethodFullName(enclosingTypeName, lambdaName, signature) - - NewMethod() - .name(lambdaName) - .fullName(lambdaFullName) - .signature(signature) - .filename(filename) - .code("") - - private def addClosureBindingsToDiffGraph( - bindingEntries: Iterable[ClosureBindingEntry], - methodRef: NewMethodRef - ): Unit = - bindingEntries.foreach { case ClosureBindingEntry(nodeTypeInfo, closureBinding) => - diffGraph.addNode(closureBinding) - diffGraph.addEdge(closureBinding, nodeTypeInfo.node, EdgeTypes.REF) - diffGraph.addEdge(methodRef, closureBinding, EdgeTypes.CAPTURE) + .map { _ => + val typeFullName = scope.enclosingTypeDeclFullName + Ast(thisNodeForMethod(typeFullName, line(expr))) } - - private def createAndPushLambdaMethod( - expr: LambdaExpr, - lambdaMethodName: String, - implementedInfo: LambdaImplementedInfo, - localsForCaptured: Seq[NewLocal], - expectedLambdaType: ExpectedType - ): NewMethod = - val implementedMethod = implementedInfo.implementedMethod - val implementedInterface = implementedInfo.implementedInterface - - // We need to get this information from the expected type as the JavaParser - // symbol solver returns the erased types when resolving the lambda itself. - val expectedTypeParamTypes = genericParamTypeMapForLambda(expectedLambdaType) - val parametersWithoutThis = - buildParamListForLambda(expr, implementedMethod, expectedTypeParamTypes) - - val returnType = - getLambdaReturnType(implementedInterface, implementedMethod, expectedTypeParamTypes) - - val lambdaMethodBody = astForLambdaBody(expr.getBody, localsForCaptured, returnType) - - val thisParam = lambdaMethodBody.nodes - .collect { case identifier: NewIdentifier => identifier } - .find { identifier => - identifier.name == NameConstants.This || identifier.name == NameConstants.Super - } - .map { _ => - val typeFullName = scope.enclosingTypeDeclFullName - Ast(thisNodeForMethod(typeFullName, line(expr))) - } - .toList - - val parameters = thisParam ++ parametersWithoutThis - - val lambdaMethodNode = - createLambdaMethodNode(lambdaMethodName, parametersWithoutThis, returnType) - val returnNode = newMethodReturnNode( - returnType.getOrElse(TypeConstants.Any), - None, - line(expr), - column(expr) + .toList + + val parameters = thisParam ++ parametersWithoutThis + + val lambdaMethodNode = + createLambdaMethodNode(lambdaMethodName, parametersWithoutThis, returnType) + val returnNode = newMethodReturnNode( + returnType.getOrElse(TypeConstants.Any), + None, + line(expr), + column(expr) + ) + val virtualModifier = Some(newModifierNode(ModifierTypes.VIRTUAL)) + val staticModifier = Option.when(thisParam.isEmpty)(newModifierNode(ModifierTypes.STATIC)) + val privateModifier = Some(newModifierNode(ModifierTypes.PRIVATE)) + + val modifiers = List(virtualModifier, staticModifier, privateModifier).flatten.map(Ast(_)) + + val lambdaParameterNamesToNodes = + parameters + .flatMap(_.root) + .collect { case param: NewMethodParameterIn => param } + .map { param => param.name -> param } + .toMap + + val identifiersMatchingParams = lambdaMethodBody.nodes + .collect { case identifier: NewIdentifier => identifier } + .filter { identifier => lambdaParameterNamesToNodes.contains(identifier.name) } + + val lambdaMethodAstWithoutRefs = + Ast(lambdaMethodNode) + .withChildren(parameters) + .withChild(lambdaMethodBody) + .withChild(Ast(returnNode)) + .withChildren(modifiers) + + val lambdaMethodAst = + identifiersMatchingParams.foldLeft(lambdaMethodAstWithoutRefs)((ast, identifier) => + ast.withRefEdge(identifier, lambdaParameterNamesToNodes(identifier.name)) ) - val virtualModifier = Some(newModifierNode(ModifierTypes.VIRTUAL)) - val staticModifier = Option.when(thisParam.isEmpty)(newModifierNode(ModifierTypes.STATIC)) - val privateModifier = Some(newModifierNode(ModifierTypes.PRIVATE)) - - val modifiers = List(virtualModifier, staticModifier, privateModifier).flatten.map(Ast(_)) - - val lambdaParameterNamesToNodes = - parameters - .flatMap(_.root) - .collect { case param: NewMethodParameterIn => param } - .map { param => param.name -> param } - .toMap - - val identifiersMatchingParams = lambdaMethodBody.nodes - .collect { case identifier: NewIdentifier => identifier } - .filter { identifier => lambdaParameterNamesToNodes.contains(identifier.name) } - - val lambdaMethodAstWithoutRefs = - Ast(lambdaMethodNode) - .withChildren(parameters) - .withChild(lambdaMethodBody) - .withChild(Ast(returnNode)) - .withChildren(modifiers) - - val lambdaMethodAst = - identifiersMatchingParams.foldLeft(lambdaMethodAstWithoutRefs)((ast, identifier) => - ast.withRefEdge(identifier, lambdaParameterNamesToNodes(identifier.name)) - ) - scope.addLambdaMethod(lambdaMethodAst) - - lambdaMethodNode - end createAndPushLambdaMethod - - private def createAndPushLambdaTypeDecl( - lambdaMethodNode: NewMethod, - implementedInfo: LambdaImplementedInfo - ): NewTypeDecl = - val inheritsFromTypeFullName = - implementedInfo.implementedInterface - .flatMap(typeInfoCalc.fullName) - .orElse(Some(TypeConstants.Object)) - .toList - - typeInfoCalc.registerType(lambdaMethodNode.fullName) - val lambdaTypeDeclNode = - NewTypeDecl() - .fullName(lambdaMethodNode.fullName) - .name(lambdaMethodNode.name) - .inheritsFromTypeFullName(inheritsFromTypeFullName) - scope.addLocalDecl(Ast(lambdaTypeDeclNode)) - - lambdaTypeDeclNode - end createAndPushLambdaTypeDecl - - private def getLambdaImplementedInfo( - expr: LambdaExpr, - expectedType: ExpectedType - ): LambdaImplementedInfo = - val maybeImplementedType = - val maybeResolved = tryWithSafeStackOverflow(expr.calculateResolvedType()) - maybeResolved.toOption - .orElse(expectedType.resolvedType) - .collect { case refType: ResolvedReferenceType => refType } - - val maybeImplementedInterface = maybeImplementedType.flatMap(_.getTypeDeclaration.toScala) - - if maybeImplementedInterface.isEmpty then - val location = s"$filename:${line(expr)}:${column(expr)}" - logger.debug( - s"Could not resolve the interface implemented by a lambda. Type info may be missing: $location. Type info may be missing." - ) + scope.addLambdaMethod(lambdaMethodAst) - val maybeBoundMethod = maybeImplementedInterface.flatMap { interface => - interface.getDeclaredMethods.asScala - .filter(_.isAbstract) - .filterNot { method => - // Filter out java.lang.Object methods re-declared by the interface as these are also considered abstract. - // See https://docs.oracle.com/javase/8/docs/api/java/lang/FunctionalInterface.html for details. - Try(method.getSignature) match - case Success(signature) => ObjectMethodSignatures.contains(signature) - case Failure(_) => - false // If the signature could not be calculated, it's probably not a standard object method. - } - .headOption - } + lambdaMethodNode + end createAndPushLambdaMethod - LambdaImplementedInfo(maybeImplementedType, maybeBoundMethod) - end getLambdaImplementedInfo - - // TODO: All of this will be thrown out, probably - private def astForLambdaExpr(expr: LambdaExpr, expectedType: ExpectedType): Ast = - scope.pushMethodScope(NewMethod(), expectedType) - - val lambdaMethodName = nextLambdaName() - - val closureBindingsForCapturedVars = closureBindingsForCapturedNodes(lambdaMethodName) - val localsForCaptured = localsForCapturedNodes(closureBindingsForCapturedVars) - val implementedInfo = getLambdaImplementedInfo(expr, expectedType) - val lambdaMethodNode = - createAndPushLambdaMethod( - expr, - lambdaMethodName, - implementedInfo, - localsForCaptured, - expectedType - ) - val typeNameLookup = - lambdaMethodNode.fullName.takeWhile(_ != ':').split("\\.").dropRight(1).mkString(".") - val methodRef = - NewMethodRef() - .methodFullName(lambdaMethodNode.fullName) - .typeFullName(lambdaMethodNode.fullName) - .code(lambdaMethodNode.fullName) - .dynamicTypeHintFullName(packagesJarMappings.getOrElse( - typeNameLookup, - mutable.Set.empty - ).toSeq) - - addClosureBindingsToDiffGraph(closureBindingsForCapturedVars, methodRef) - - val interfaceBinding = implementedInfo.implementedMethod.map { implementedMethod => - newBindingNode( - implementedMethod.getName, - lambdaMethodNode.signature, - lambdaMethodNode.fullName - ) - } - - val bindingTable = getLambdaBindingTable( - LambdaBindingInfo( - lambdaMethodNode.fullName, - implementedInfo.implementedInterface, - interfaceBinding - ) - ) - - val lambdaTypeDeclNode = createAndPushLambdaTypeDecl(lambdaMethodNode, implementedInfo) - createBindingNodes(lambdaTypeDeclNode, bindingTable) + private def createAndPushLambdaTypeDecl( + lambdaMethodNode: NewMethod, + implementedInfo: LambdaImplementedInfo + ): NewTypeDecl = + val inheritsFromTypeFullName = + implementedInfo.implementedInterface + .flatMap(typeInfoCalc.fullName) + .orElse(Some(TypeConstants.Object)) + .toList - scope.popScope() - Ast(methodRef) - end astForLambdaExpr - - private def astForLiteralExpr(expr: LiteralExpr): Ast = - val typeFullName = expressionReturnTypeFullName(expr).getOrElse(TypeConstants.Any) - val literalNode = - NewLiteral() - .code(expr.toString) - .lineNumber(line(expr)) - .columnNumber(column(expr)) - .typeFullName(typeFullName) - Ast(literalNode) - - private def getExpectedParamType( - maybeResolvedCall: Try[ResolvedMethodLikeDeclaration], - idx: Int - ): ExpectedType = - maybeResolvedCall.toOption - .map { methodDecl => - val paramCount = methodDecl.getNumberOfParams - - val resolvedType = if idx < paramCount then - Some(methodDecl.getParam(idx).getType) - else if paramCount > 0 && methodDecl.getParam(paramCount - 1).isVariadic then - Some(methodDecl.getParam(paramCount - 1).getType) - else - None - - val typeName = resolvedType.flatMap(typeInfoCalc.fullName) - ExpectedType(typeName, resolvedType) + typeInfoCalc.registerType(lambdaMethodNode.fullName) + val lambdaTypeDeclNode = + NewTypeDecl() + .fullName(lambdaMethodNode.fullName) + .name(lambdaMethodNode.name) + .inheritsFromTypeFullName(inheritsFromTypeFullName) + scope.addLocalDecl(Ast(lambdaTypeDeclNode)) + + lambdaTypeDeclNode + end createAndPushLambdaTypeDecl + + private def getLambdaImplementedInfo( + expr: LambdaExpr, + expectedType: ExpectedType + ): LambdaImplementedInfo = + val maybeImplementedType = + val maybeResolved = tryWithSafeStackOverflow(expr.calculateResolvedType()) + maybeResolved.toOption + .orElse(expectedType.resolvedType) + .collect { case refType: ResolvedReferenceType => refType } + + val maybeImplementedInterface = maybeImplementedType.flatMap(_.getTypeDeclaration.toScala) + + if maybeImplementedInterface.isEmpty then + val location = s"$filename:${line(expr)}:${column(expr)}" + logger.debug( + s"Could not resolve the interface implemented by a lambda. Type info may be missing: $location. Type info may be missing." + ) + + val maybeBoundMethod = maybeImplementedInterface.flatMap { interface => + interface.getDeclaredMethods.asScala + .filter(_.isAbstract) + .filterNot { method => + // Filter out java.lang.Object methods re-declared by the interface as these are also considered abstract. + // See https://docs.oracle.com/javase/8/docs/api/java/lang/FunctionalInterface.html for details. + Try(method.getSignature) match + case Success(signature) => ObjectMethodSignatures.contains(signature) + case Failure(_) => + false // If the signature could not be calculated, it's probably not a standard object method. } - .getOrElse(ExpectedType.empty) - - private def dispatchTypeForCall( - maybeDecl: Try[ResolvedMethodDeclaration], - maybeScope: Option[Expression] - ): String = - maybeScope match - case Some(_: SuperExpr) => - DispatchTypes.STATIC_DISPATCH - case _ => - maybeDecl match - case Success(decl) => - if decl.isStatic then DispatchTypes.STATIC_DISPATCH - else DispatchTypes.DYNAMIC_DISPATCH - - case _ => - DispatchTypes.DYNAMIC_DISPATCH - - private def targetTypeForCall(callExpr: MethodCallExpr): Option[String] = - val maybeType = callExpr.getScope.toScala match - case Some(callScope: ThisExpr) => - expressionReturnTypeFullName(callScope) - .orElse(scope.enclosingTypeDeclFullName) - - case Some(callScope: SuperExpr) => - expressionReturnTypeFullName(callScope) - .orElse(scope.enclosingTypeDecl.flatMap(_.inheritsFromTypeFullName.headOption)) - - case Some(scope) => expressionReturnTypeFullName(scope) - - case None => - tryWithSafeStackOverflow(callExpr.resolve()).toOption - .flatMap { methodDeclOption => - if methodDeclOption.isStatic then - typeInfoCalc.fullName(methodDeclOption.declaringType()) - else scope.enclosingTypeDeclFullName - } - .orElse(scope.enclosingTypeDeclFullName) - - maybeType.map(typeInfoCalc.registerType) - end targetTypeForCall - - private def argAstsForCall( - call: Node, - tryResolvedDecl: Try[ResolvedMethodLikeDeclaration], - args: NodeList[Expression] - ): Seq[Ast] = - val hasVariadicParameter = tryResolvedDecl.map(_.hasVariadicParameter).getOrElse(false) - val paramCount = tryResolvedDecl.map(_.getNumberOfParams).getOrElse(-1) - - val argsAsts = args.asScala.zipWithIndex.flatMap { case (arg, idx) => - val expectedType = getExpectedParamType(tryResolvedDecl, idx) - astsForExpression(arg, expectedType) - }.toList - - tryResolvedDecl match - case Success(_) if hasVariadicParameter => - val expectedVariadicTypeFullName = - getExpectedParamType(tryResolvedDecl, paramCount - 1).fullName - val (regularArgs, varargs) = argsAsts.splitAt(paramCount - 1) - val arrayInitializer = newOperatorCallNode( - Operators.arrayInitializer, - Operators.arrayInitializer, - expectedVariadicTypeFullName, - line(call), - column(call) - ) + .headOption + } - val arrayInitializerAst = callAst(arrayInitializer, varargs) + LambdaImplementedInfo(maybeImplementedType, maybeBoundMethod) + end getLambdaImplementedInfo - regularArgs ++ Seq(arrayInitializerAst) + // TODO: All of this will be thrown out, probably + private def astForLambdaExpr(expr: LambdaExpr, expectedType: ExpectedType): Ast = + scope.pushMethodScope(NewMethod(), expectedType) - case _ => argsAsts - end argAstsForCall + val lambdaMethodName = nextLambdaName() - private def getArgumentCodeString(args: NodeList[Expression]): String = - args.asScala - .map { - case _: LambdaExpr => "" - case other => other.toString - } - .mkString(", ") - - private def astForMethodCall(call: MethodCallExpr, expectedReturnType: ExpectedType): Ast = - val maybeResolvedCall = tryWithSafeStackOverflow(call.resolve()) - val argumentAsts = argAstsForCall(call, maybeResolvedCall, call.getArguments) + val closureBindingsForCapturedVars = closureBindingsForCapturedNodes(lambdaMethodName) + val localsForCaptured = localsForCapturedNodes(closureBindingsForCapturedVars) + val implementedInfo = getLambdaImplementedInfo(expr, expectedType) + val lambdaMethodNode = + createAndPushLambdaMethod( + expr, + lambdaMethodName, + implementedInfo, + localsForCaptured, + expectedType + ) + val typeNameLookup = + lambdaMethodNode.fullName.takeWhile(_ != ':').split("\\.").dropRight(1).mkString(".") + val methodRef = + NewMethodRef() + .methodFullName(lambdaMethodNode.fullName) + .typeFullName(lambdaMethodNode.fullName) + .code(lambdaMethodNode.fullName) + .dynamicTypeHintFullName(packagesJarMappings.getOrElse( + typeNameLookup, + mutable.Set.empty + ).toSeq) - val expressionTypeFullName = - expressionReturnTypeFullName(call).orElse(expectedReturnType.fullName) + addClosureBindingsToDiffGraph(closureBindingsForCapturedVars, methodRef) - val argumentTypes = argumentTypesForMethodLike(maybeResolvedCall) - val returnType = maybeResolvedCall - .map { resolvedCall => - typeInfoCalc.fullName(resolvedCall.getReturnType, ResolvedTypeParametersMap.empty()) - } - .toOption - .flatten - .orElse(expressionTypeFullName) - val dispatchType = dispatchTypeForCall(maybeResolvedCall, call.getScope.toScala) - - val receiverTypeOption = targetTypeForCall(call) - val scopeAsts = call.getScope.toScala match - case Some(scope) => astsForExpression(scope, ExpectedType(receiverTypeOption)) - - case None => - val objectNode = - createObjectNode(receiverTypeOption, call, dispatchType) - for - obj <- objectNode - thisParam <- scope.lookupVariable(NameConstants.This).variableNode - do diffGraph.addEdge(obj, thisParam, EdgeTypes.REF) - objectNode.map(Ast(_)).toList - - val receiverType = - receiverTypeOption.orElse(scopeAsts.rootType).filter(_ != TypeConstants.Any) - - val argumentsCode = getArgumentCodeString(call.getArguments) - val codePrefix = codePrefixForMethodCall(call) - val callCode = s"$codePrefix${call.getNameAsString}($argumentsCode)" - - val callName = call.getNameAsString - val namespace = receiverType.getOrElse(Defines.UnresolvedNamespace) - val signature = composeSignature(returnType, argumentTypes, argumentAsts.size) - val methodFullName = composeMethodFullName(namespace, callName, signature) - val typeFullNameStr = expressionTypeFullName.getOrElse(TypeConstants.Any) - val callRoot = NewCall() - .name(callName) - .methodFullName(methodFullName) - .signature(signature) - .code(callCode) - .dispatchType(dispatchType) - .lineNumber(line(call)) - .columnNumber(column(call)) - .typeFullName(typeFullNameStr) - callRoot.dynamicTypeHintFullName( - packagesJarMappings - .getOrElse( - methodFullName.takeWhile(_ != ':').split("\\.").dropRight(1).mkString("."), - mutable.Set.empty - ) - .toSeq + val interfaceBinding = implementedInfo.implementedMethod.map { implementedMethod => + newBindingNode( + implementedMethod.getName, + lambdaMethodNode.signature, + lambdaMethodNode.fullName ) - callAst(callRoot, argumentAsts, scopeAsts.headOption) - end astForMethodCall + } + + val bindingTable = getLambdaBindingTable( + LambdaBindingInfo( + lambdaMethodNode.fullName, + implementedInfo.implementedInterface, + interfaceBinding + ) + ) + + val lambdaTypeDeclNode = createAndPushLambdaTypeDecl(lambdaMethodNode, implementedInfo) + createBindingNodes(lambdaTypeDeclNode, bindingTable) + + scope.popScope() + Ast(methodRef) + end astForLambdaExpr + + private def astForLiteralExpr(expr: LiteralExpr): Ast = + val typeFullName = expressionReturnTypeFullName(expr).getOrElse(TypeConstants.Any) + val literalNode = + NewLiteral() + .code(expr.toString) + .lineNumber(line(expr)) + .columnNumber(column(expr)) + .typeFullName(typeFullName) + Ast(literalNode) + + private def getExpectedParamType( + maybeResolvedCall: Try[ResolvedMethodLikeDeclaration], + idx: Int + ): ExpectedType = + maybeResolvedCall.toOption + .map { methodDecl => + val paramCount = methodDecl.getNumberOfParams + + val resolvedType = if idx < paramCount then + Some(methodDecl.getParam(idx).getType) + else if paramCount > 0 && methodDecl.getParam(paramCount - 1).isVariadic then + Some(methodDecl.getParam(paramCount - 1).getType) + else + None - private def astForSuperExpr(superExpr: SuperExpr, expectedType: ExpectedType): Ast = - val typeFullName = - expressionReturnTypeFullName(superExpr) - .orElse(expectedType.fullName) - .getOrElse(TypeConstants.Any) + val typeName = resolvedType.flatMap(typeInfoCalc.fullName) + ExpectedType(typeName, resolvedType) + } + .getOrElse(ExpectedType.empty) + + private def dispatchTypeForCall( + maybeDecl: Try[ResolvedMethodDeclaration], + maybeScope: Option[Expression] + ): String = + maybeScope match + case Some(_: SuperExpr) => + DispatchTypes.STATIC_DISPATCH + case _ => + maybeDecl match + case Success(decl) => + if decl.isStatic then DispatchTypes.STATIC_DISPATCH + else DispatchTypes.DYNAMIC_DISPATCH + + case _ => + DispatchTypes.DYNAMIC_DISPATCH + + private def targetTypeForCall(callExpr: MethodCallExpr): Option[String] = + val maybeType = callExpr.getScope.toScala match + case Some(callScope: ThisExpr) => + expressionReturnTypeFullName(callScope) + .orElse(scope.enclosingTypeDeclFullName) + + case Some(callScope: SuperExpr) => + expressionReturnTypeFullName(callScope) + .orElse(scope.enclosingTypeDecl.flatMap(_.inheritsFromTypeFullName.headOption)) + + case Some(scope) => expressionReturnTypeFullName(scope) + + case None => + tryWithSafeStackOverflow(callExpr.resolve()).toOption + .flatMap { methodDeclOption => + if methodDeclOption.isStatic then + typeInfoCalc.fullName(methodDeclOption.declaringType()) + else scope.enclosingTypeDeclFullName + } + .orElse(scope.enclosingTypeDeclFullName) + + maybeType.map(typeInfoCalc.registerType) + end targetTypeForCall + + private def argAstsForCall( + call: Node, + tryResolvedDecl: Try[ResolvedMethodLikeDeclaration], + args: NodeList[Expression] + ): Seq[Ast] = + val hasVariadicParameter = tryResolvedDecl.map(_.hasVariadicParameter).getOrElse(false) + val paramCount = tryResolvedDecl.map(_.getNumberOfParams).getOrElse(-1) + + val argsAsts = args.asScala.zipWithIndex.flatMap { case (arg, idx) => + val expectedType = getExpectedParamType(tryResolvedDecl, idx) + astsForExpression(arg, expectedType) + }.toList + + tryResolvedDecl match + case Success(_) if hasVariadicParameter => + val expectedVariadicTypeFullName = + getExpectedParamType(tryResolvedDecl, paramCount - 1).fullName + val (regularArgs, varargs) = argsAsts.splitAt(paramCount - 1) + val arrayInitializer = newOperatorCallNode( + Operators.arrayInitializer, + Operators.arrayInitializer, + expectedVariadicTypeFullName, + line(call), + column(call) + ) - typeInfoCalc.registerType(typeFullName) + val arrayInitializerAst = callAst(arrayInitializer, varargs) - val identifier = - identifierNode(superExpr, NameConstants.This, NameConstants.Super, typeFullName) - Ast(identifier) + regularArgs ++ Seq(arrayInitializerAst) - private def astsForParameterList(parameters: NodeList[Parameter]): Seq[Ast] = - parameters.asScala.toList.zipWithIndex.map { case (param, idx) => - astForParameter(param, idx + 1) - } + case _ => argsAsts + end argAstsForCall - private def astForParameter(parameter: Parameter, childNum: Int): Ast = - val maybeArraySuffix = if parameter.isVarArgs then "[]" else "" - val typeFullName = - typeInfoCalc - .fullName(parameter.getType) - .orElse(scope.lookupType(parameter.getTypeAsString)) - .map(_ ++ maybeArraySuffix) - .getOrElse(guessTypeFullName(parameter.getTypeAsString)) - val evalStrat = - if parameter.getType.isPrimitiveType then EvaluationStrategies.BY_VALUE - else EvaluationStrategies.BY_SHARING - typeInfoCalc.registerType(typeFullName) - val parameterNode = NewMethodParameterIn() - .name(parameter.getName.toString) - .code(parameter.toString) - .lineNumber(line(parameter)) - .columnNumber(column(parameter)) - .evaluationStrategy(evalStrat) - .typeFullName(typeFullName) - .index(childNum) - .order(childNum) - val annotationAsts = parameter.getAnnotations.asScala.map(astForAnnotationExpr) - val ast = Ast(parameterNode) + private def getArgumentCodeString(args: NodeList[Expression]): String = + args.asScala + .map { + case _: LambdaExpr => "" + case other => other.toString + } + .mkString(", ") + + private def astForMethodCall(call: MethodCallExpr, expectedReturnType: ExpectedType): Ast = + val maybeResolvedCall = tryWithSafeStackOverflow(call.resolve()) + val argumentAsts = argAstsForCall(call, maybeResolvedCall, call.getArguments) - scope.addParameter(parameterNode) + val expressionTypeFullName = + expressionReturnTypeFullName(call).orElse(expectedReturnType.fullName) - ast.withChildren(annotationAsts) - end astForParameter + val argumentTypes = argumentTypesForMethodLike(maybeResolvedCall) + val returnType = maybeResolvedCall + .map { resolvedCall => + typeInfoCalc.fullName(resolvedCall.getReturnType, ResolvedTypeParametersMap.empty()) + } + .toOption + .flatten + .orElse(expressionTypeFullName) + val dispatchType = dispatchTypeForCall(maybeResolvedCall, call.getScope.toScala) + + val receiverTypeOption = targetTypeForCall(call) + val scopeAsts = call.getScope.toScala match + case Some(scope) => astsForExpression(scope, ExpectedType(receiverTypeOption)) + + case None => + val objectNode = + createObjectNode(receiverTypeOption, call, dispatchType) + for + obj <- objectNode + thisParam <- scope.lookupVariable(NameConstants.This).variableNode + do diffGraph.addEdge(obj, thisParam, EdgeTypes.REF) + objectNode.map(Ast(_)).toList + + val receiverType = + receiverTypeOption.orElse(scopeAsts.rootType).filter(_ != TypeConstants.Any) + + val argumentsCode = getArgumentCodeString(call.getArguments) + val codePrefix = codePrefixForMethodCall(call) + val callCode = s"$codePrefix${call.getNameAsString}($argumentsCode)" + + val callName = call.getNameAsString + val namespace = receiverType.getOrElse(Defines.UnresolvedNamespace) + val signature = composeSignature(returnType, argumentTypes, argumentAsts.size) + val methodFullName = composeMethodFullName(namespace, callName, signature) + val typeFullNameStr = expressionTypeFullName.getOrElse(TypeConstants.Any) + val callRoot = NewCall() + .name(callName) + .methodFullName(methodFullName) + .signature(signature) + .code(callCode) + .dispatchType(dispatchType) + .lineNumber(line(call)) + .columnNumber(column(call)) + .typeFullName(typeFullNameStr) + callRoot.dynamicTypeHintFullName( + packagesJarMappings + .getOrElse( + methodFullName.takeWhile(_ != ':').split("\\.").dropRight(1).mkString("."), + mutable.Set.empty + ) + .toSeq + ) + callAst(callRoot, argumentAsts, scopeAsts.headOption) + end astForMethodCall + + private def astForSuperExpr(superExpr: SuperExpr, expectedType: ExpectedType): Ast = + val typeFullName = + expressionReturnTypeFullName(superExpr) + .orElse(expectedType.fullName) + .getOrElse(TypeConstants.Any) + + typeInfoCalc.registerType(typeFullName) + + val identifier = + identifierNode(superExpr, NameConstants.This, NameConstants.Super, typeFullName) + Ast(identifier) + + private def astsForParameterList(parameters: NodeList[Parameter]): Seq[Ast] = + parameters.asScala.toList.zipWithIndex.map { case (param, idx) => + astForParameter(param, idx + 1) + } + + private def astForParameter(parameter: Parameter, childNum: Int): Ast = + val maybeArraySuffix = if parameter.isVarArgs then "[]" else "" + val typeFullName = + typeInfoCalc + .fullName(parameter.getType) + .orElse(scope.lookupType(parameter.getTypeAsString)) + .map(_ ++ maybeArraySuffix) + .getOrElse(guessTypeFullName(parameter.getTypeAsString)) + val evalStrat = + if parameter.getType.isPrimitiveType then EvaluationStrategies.BY_VALUE + else EvaluationStrategies.BY_SHARING + typeInfoCalc.registerType(typeFullName) + val parameterNode = NewMethodParameterIn() + .name(parameter.getName.toString) + .code(parameter.toString) + .lineNumber(line(parameter)) + .columnNumber(column(parameter)) + .evaluationStrategy(evalStrat) + .typeFullName(typeFullName) + .index(childNum) + .order(childNum) + val annotationAsts = parameter.getAnnotations.asScala.map(astForAnnotationExpr) + val ast = Ast(parameterNode) + + scope.addParameter(parameterNode) + + ast.withChildren(annotationAsts) + end astForParameter end AstCreator diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/ConfigFileCreationPass.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/ConfigFileCreationPass.scala index d66668d8..2cf5bbbb 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/ConfigFileCreationPass.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/ConfigFileCreationPass.scala @@ -6,44 +6,44 @@ import io.shiftleft.codepropertygraph.Cpg class ConfigFileCreationPass(cpg: Cpg) extends XConfigFileCreationPass(cpg): - override val configFileFilters: List[File => Boolean] = List( - // JAVA_INTERNAL - extensionFilter(".properties"), - // HTML - pathRegexFilter(".*resources/templates.*.html"), - // JSP - extensionFilter(".jsp"), - // Velocity files, see https://velocity.apache.org - extensionFilter(".vm"), - // For Terraform secrets - extensionFilter(".tf"), - extensionFilter(".tfvars"), - // PLAY - pathEndFilter("routes"), - pathEndFilter("application.conf"), - // SERVLET - pathEndFilter("web.xml"), - // JSF - pathEndFilter("faces-config.xml"), - // STRUTS - pathEndFilter("struts.xml"), - // DIRECT WEB REMOTING - pathEndFilter("dwr.xml"), - // MYBATIS - mybatisFilter, - // BUILD SYSTEM - pathEndFilter("build.gradle"), - pathEndFilter("build.gradle.kts"), - // ANDROID - pathEndFilter("AndroidManifest.xml"), - // Bom - pathEndFilter("bom.json"), - pathEndFilter(".cdx.json"), - pathEndFilter("chennai.json"), - extensionFilter(".yml"), - extensionFilter(".yaml") - ) + override val configFileFilters: List[File => Boolean] = List( + // JAVA_INTERNAL + extensionFilter(".properties"), + // HTML + pathRegexFilter(".*resources/templates.*.html"), + // JSP + extensionFilter(".jsp"), + // Velocity files, see https://velocity.apache.org + extensionFilter(".vm"), + // For Terraform secrets + extensionFilter(".tf"), + extensionFilter(".tfvars"), + // PLAY + pathEndFilter("routes"), + pathEndFilter("application.conf"), + // SERVLET + pathEndFilter("web.xml"), + // JSF + pathEndFilter("faces-config.xml"), + // STRUTS + pathEndFilter("struts.xml"), + // DIRECT WEB REMOTING + pathEndFilter("dwr.xml"), + // MYBATIS + mybatisFilter, + // BUILD SYSTEM + pathEndFilter("build.gradle"), + pathEndFilter("build.gradle.kts"), + // ANDROID + pathEndFilter("AndroidManifest.xml"), + // Bom + pathEndFilter("bom.json"), + pathEndFilter(".cdx.json"), + pathEndFilter("chennai.json"), + extensionFilter(".yml"), + extensionFilter(".yaml") + ) - private def mybatisFilter(file: File): Boolean = - file.canonicalPath.contains("batis") && file.extension.contains(".xml") + private def mybatisFilter(file: File): Boolean = + file.canonicalPath.contains("batis") && file.extension.contains(".xml") end ConfigFileCreationPass diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/JavaTypeHintCallLinker.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/JavaTypeHintCallLinker.scala index afbe9725..397ab1f7 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/JavaTypeHintCallLinker.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/JavaTypeHintCallLinker.scala @@ -10,11 +10,11 @@ import java.util.regex.Pattern class JavaTypeHintCallLinker(cpg: Cpg) extends XTypeHintCallLinker(cpg): - override protected def calls: Iterator[Call] = - cpg.call - .nameNot(".*", ".*") - .filter(c => - calleeNames(c).nonEmpty && c.callee.fullNameNot( - Pattern.quote(Defines.UnresolvedNamespace) + ".*" - ).isEmpty - ) + override protected def calls: Iterator[Call] = + cpg.call + .nameNot(".*", ".*") + .filter(c => + calleeNames(c).nonEmpty && c.callee.fullNameNot( + Pattern.quote(Defines.UnresolvedNamespace) + ".*" + ).isEmpty + ) diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/JavaTypeRecoveryPass.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/JavaTypeRecoveryPass.scala index 7499821a..7dbf8fc8 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/JavaTypeRecoveryPass.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/JavaTypeRecoveryPass.scala @@ -10,22 +10,22 @@ import overflowdb.BatchedUpdate.DiffGraphBuilder class JavaTypeRecoveryPass(cpg: Cpg, config: XTypeRecoveryConfig = XTypeRecoveryConfig()) extends XTypeRecoveryPass[Method](cpg, config): - override protected def generateRecoveryPass(state: XTypeRecoveryState): XTypeRecovery[Method] = - new JavaTypeRecovery(cpg, state) + override protected def generateRecoveryPass(state: XTypeRecoveryState): XTypeRecovery[Method] = + new JavaTypeRecovery(cpg, state) private class JavaTypeRecovery(cpg: Cpg, state: XTypeRecoveryState) extends XTypeRecovery[Method](cpg, state): - override def compilationUnit: Iterator[Method] = cpg.method.isExternal(false).iterator + override def compilationUnit: Iterator[Method] = cpg.method.isExternal(false).iterator - override def generateRecoveryForCompilationUnitTask( - unit: Method, - builder: DiffGraphBuilder - ): RecoverForXCompilationUnit[Method] = - val newConfig = state.config.copy(enabledDummyTypes = - state.isFinalIteration && state.config.enabledDummyTypes - ) - new RecoverForJavaFile(cpg, unit, builder, state.copy(config = newConfig)) + override def generateRecoveryForCompilationUnitTask( + unit: Method, + builder: DiffGraphBuilder + ): RecoverForXCompilationUnit[Method] = + val newConfig = state.config.copy(enabledDummyTypes = + state.isFinalIteration && state.config.enabledDummyTypes + ) + new RecoverForJavaFile(cpg, unit, builder, state.copy(config = newConfig)) private class RecoverForJavaFile( cpg: Cpg, @@ -34,38 +34,38 @@ private class RecoverForJavaFile( state: XTypeRecoveryState ) extends RecoverForXCompilationUnit[Method](cpg, cu, builder, state): - private def javaNodeToLocalKey(n: AstNode): Option[LocalKey] = n match - case i: Identifier if i.name == "this" && i.code == "super" => Option(LocalVar("super")) - case _ => SBKey.fromNodeToLocalKey(n) + private def javaNodeToLocalKey(n: AstNode): Option[LocalKey] = n match + case i: Identifier if i.name == "this" && i.code == "super" => Option(LocalVar("super")) + case _ => SBKey.fromNodeToLocalKey(n) - override protected val symbolTable = new SymbolTable[LocalKey](javaNodeToLocalKey) + override protected val symbolTable = new SymbolTable[LocalKey](javaNodeToLocalKey) - override protected def isConstructor(c: Call): Boolean = isConstructor(c.name) + override protected def isConstructor(c: Call): Boolean = isConstructor(c.name) - override protected def isConstructor(name: String): Boolean = - !name.isBlank && name.charAt(0).isUpper + override protected def isConstructor(name: String): Boolean = + !name.isBlank && name.charAt(0).isUpper - override protected def postVisitImports(): Unit = - symbolTable.view.foreach { case (k, ts) => - val tss = ts.filterNot(_.startsWith(Defines.UnresolvedNamespace)) - if tss.isEmpty then - symbolTable.remove(k) - else - symbolTable.put(k, tss) - } + override protected def postVisitImports(): Unit = + symbolTable.view.foreach { case (k, ts) => + val tss = ts.filterNot(_.startsWith(Defines.UnresolvedNamespace)) + if tss.isEmpty then + symbolTable.remove(k) + else + symbolTable.put(k, tss) + } - // There seems to be issues with inferring these, often due to situations where super and this are confused on name - // and code properties. - override protected def storeIdentifierTypeInfo(i: Identifier, types: Seq[String]): Unit = - if i.name != "this" then - super.storeIdentifierTypeInfo(i, types) + // There seems to be issues with inferring these, often due to situations where super and this are confused on name + // and code properties. + override protected def storeIdentifierTypeInfo(i: Identifier, types: Seq[String]): Unit = + if i.name != "this" then + super.storeIdentifierTypeInfo(i, types) - override protected def storeCallTypeInfo(c: Call, types: Seq[String]): Unit = - if types.nonEmpty then - state.changesWereMade.compareAndSet(false, true) - val signedTypes = types.map { - case t if t.endsWith(c.signature) => t - case t => s"$t:${c.signature}" - } - builder.setNodeProperty(c, PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, signedTypes) + override protected def storeCallTypeInfo(c: Call, types: Seq[String]): Unit = + if types.nonEmpty then + state.changesWereMade.compareAndSet(false, true) + val signedTypes = types.map { + case t if t.endsWith(c.signature) => t + case t => s"$t:${c.signature}" + } + builder.setNodeProperty(c, PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, signedTypes) end RecoverForJavaFile diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/TypeInferencePass.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/TypeInferencePass.scala index f9d84d28..12ec660b 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/TypeInferencePass.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/passes/TypeInferencePass.scala @@ -17,104 +17,104 @@ import io.appthreat.javasrc2cpg.typesolvers.TypeInfoCalculator.TypeConstants class TypeInferencePass(cpg: Cpg) extends ConcurrentWriterCpgPass[Call](cpg): - private val cache = new GuavaCache(CacheBuilder.newBuilder().build[String, Option[Method]]()) - private val resolvedMethodIndex = cpg.method - .filterNot(_.fullName.startsWith(Defines.UnresolvedNamespace)) - .filterNot(_.signature.startsWith(Defines.UnresolvedSignature)) - .groupBy(_.name) - - private case class NameParts(typeDecl: Option[String], signature: String) - - override def generateParts(): Array[Call] = - cpg.call - .filter(_.signature.startsWith(Defines.UnresolvedSignature)) - .filterNot { _.name.startsWith(UnresolvedNamespace) } - .toArray - - private def isMatchingMethod(method: Method, call: Call, callNameParts: NameParts): Boolean = - // An erroneous `this` argument is added for unresolved calls to static methods. - val argSizeMod = - if method.modifier.modifierType.iterator.contains(ModifierTypes.STATIC) then 1 else 0 - lazy val methodNameParts = getNameParts(method.name, method.fullName) - - val parameterSizesMatch = - (method.parameter.size == (call.argument.size - argSizeMod)) - - lazy val argTypesMatch = - doArgumentTypesMatch(method: Method, call: Call, skipCallThis = argSizeMod == 1) - - lazy val typeDeclMatches = (callNameParts.typeDecl == methodNameParts.typeDecl) - - parameterSizesMatch && argTypesMatch && typeDeclMatches - - /** Check if argument types match by comparing exact full names. An argument type of `ANY` - * always matches. - * - * TODO: Take inheritance hierarchies into account - */ - private def doArgumentTypesMatch(method: Method, call: Call, skipCallThis: Boolean): Boolean = - val callArgs = if skipCallThis then call.argument.toList.tail else call.argument.toList - - val hasDifferingArg = method.parameter.zip(callArgs).exists { case (parameter, argument) => - val maybeArgumentType = Option(argument.property(PropertyNames.TypeFullName)) - .map(_.toString()) - .getOrElse(TypeConstants.Any) - - val argMatches = - maybeArgumentType == TypeConstants.Any || maybeArgumentType == parameter.typeFullName - - !argMatches - } - - !hasDifferingArg - - private def getNameParts(name: String, fullName: String): NameParts = - val Array(qualifiedName, signature) = fullName.split(":", 2) - - val typeDeclName = qualifiedName.stripSuffix(name) match - case "" => None - - case typeDeclName => Some(typeDeclName) - - NameParts(typeDeclName, signature) - - private def getReplacementMethod(call: Call): Option[Method] = - val argTypes = - call.argument.flatMap(arg => - Option(arg.property(PropertyNames.TypeFullName)).map(_.toString) - ).mkString(":") - val callKey = - s"${call.methodFullName}:$argTypes" - cache.get(callKey).toScala.getOrElse { - val callNameParts = getNameParts(call.name, call.methodFullName) - resolvedMethodIndex.get(call.name).flatMap { candidateMethods => - val candidateMethodsIter = candidateMethods.iterator - val uniqueMatchingMethod = - candidateMethodsIter.find(isMatchingMethod(_, call, callNameParts)).flatMap { - method => - val otherMatchingMethod = - candidateMethodsIter.find(isMatchingMethod(_, call, callNameParts)) - // Only return a resulting method if exactly one matching method is found. - Option.when(otherMatchingMethod.isEmpty)(method) - } - cache.put(callKey, uniqueMatchingMethod) - uniqueMatchingMethod - } - } - end getReplacementMethod - - override def runOnPart(diffGraph: DiffGraphBuilder, call: Call): Unit = - getReplacementMethod(call).foreach { replacementMethod => - diffGraph.setNodeProperty( - call, - PropertyNames.MethodFullName, - replacementMethod.fullName - ) - diffGraph.setNodeProperty(call, PropertyNames.Signature, replacementMethod.signature) - diffGraph.setNodeProperty( - call, - PropertyNames.TypeFullName, - replacementMethod.methodReturn.typeFullName - ) + private val cache = new GuavaCache(CacheBuilder.newBuilder().build[String, Option[Method]]()) + private val resolvedMethodIndex = cpg.method + .filterNot(_.fullName.startsWith(Defines.UnresolvedNamespace)) + .filterNot(_.signature.startsWith(Defines.UnresolvedSignature)) + .groupBy(_.name) + + private case class NameParts(typeDecl: Option[String], signature: String) + + override def generateParts(): Array[Call] = + cpg.call + .filter(_.signature.startsWith(Defines.UnresolvedSignature)) + .filterNot { _.name.startsWith(UnresolvedNamespace) } + .toArray + + private def isMatchingMethod(method: Method, call: Call, callNameParts: NameParts): Boolean = + // An erroneous `this` argument is added for unresolved calls to static methods. + val argSizeMod = + if method.modifier.modifierType.iterator.contains(ModifierTypes.STATIC) then 1 else 0 + lazy val methodNameParts = getNameParts(method.name, method.fullName) + + val parameterSizesMatch = + (method.parameter.size == (call.argument.size - argSizeMod)) + + lazy val argTypesMatch = + doArgumentTypesMatch(method: Method, call: Call, skipCallThis = argSizeMod == 1) + + lazy val typeDeclMatches = (callNameParts.typeDecl == methodNameParts.typeDecl) + + parameterSizesMatch && argTypesMatch && typeDeclMatches + + /** Check if argument types match by comparing exact full names. An argument type of `ANY` always + * matches. + * + * TODO: Take inheritance hierarchies into account + */ + private def doArgumentTypesMatch(method: Method, call: Call, skipCallThis: Boolean): Boolean = + val callArgs = if skipCallThis then call.argument.toList.tail else call.argument.toList + + val hasDifferingArg = method.parameter.zip(callArgs).exists { case (parameter, argument) => + val maybeArgumentType = Option(argument.property(PropertyNames.TypeFullName)) + .map(_.toString()) + .getOrElse(TypeConstants.Any) + + val argMatches = + maybeArgumentType == TypeConstants.Any || maybeArgumentType == parameter.typeFullName + + !argMatches + } + + !hasDifferingArg + + private def getNameParts(name: String, fullName: String): NameParts = + val Array(qualifiedName, signature) = fullName.split(":", 2) + + val typeDeclName = qualifiedName.stripSuffix(name) match + case "" => None + + case typeDeclName => Some(typeDeclName) + + NameParts(typeDeclName, signature) + + private def getReplacementMethod(call: Call): Option[Method] = + val argTypes = + call.argument.flatMap(arg => + Option(arg.property(PropertyNames.TypeFullName)).map(_.toString) + ).mkString(":") + val callKey = + s"${call.methodFullName}:$argTypes" + cache.get(callKey).toScala.getOrElse { + val callNameParts = getNameParts(call.name, call.methodFullName) + resolvedMethodIndex.get(call.name).flatMap { candidateMethods => + val candidateMethodsIter = candidateMethods.iterator + val uniqueMatchingMethod = + candidateMethodsIter.find(isMatchingMethod(_, call, callNameParts)).flatMap { + method => + val otherMatchingMethod = + candidateMethodsIter.find(isMatchingMethod(_, call, callNameParts)) + // Only return a resulting method if exactly one matching method is found. + Option.when(otherMatchingMethod.isEmpty)(method) + } + cache.put(callKey, uniqueMatchingMethod) + uniqueMatchingMethod } + } + end getReplacementMethod + + override def runOnPart(diffGraph: DiffGraphBuilder, call: Call): Unit = + getReplacementMethod(call).foreach { replacementMethod => + diffGraph.setNodeProperty( + call, + PropertyNames.MethodFullName, + replacementMethod.fullName + ) + diffGraph.setNodeProperty(call, PropertyNames.Signature, replacementMethod.signature) + diffGraph.setNodeProperty( + call, + PropertyNames.TypeFullName, + replacementMethod.methodReturn.typeFullName + ) + } end TypeInferencePass diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/scope/JavaScopeElement.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/scope/JavaScopeElement.scala index e64b58c6..5ce980da 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/scope/JavaScopeElement.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/scope/JavaScopeElement.scala @@ -17,46 +17,46 @@ import io.appthreat.x2cpg.Ast import io.shiftleft.codepropertygraph.generated.nodes.AstNodeNew trait JavaScopeElement: - private val variables = mutable.Map[String, ScopeVariable]() - private val types = mutable.Map[String, String]() - private var wildcardImports: WildcardImports = NoWildcard + private val variables = mutable.Map[String, ScopeVariable]() + private val types = mutable.Map[String, String]() + private var wildcardImports: WildcardImports = NoWildcard - def addVariableToScope(variable: ScopeVariable): Unit = - variables.put(variable.name, variable) + def addVariableToScope(variable: ScopeVariable): Unit = + variables.put(variable.name, variable) - def lookupVariable(name: String): Option[ScopeVariable] = - variables.get(name) + def lookupVariable(name: String): Option[ScopeVariable] = + variables.get(name) - def addTypeToScope(name: String, typeFullName: String): Unit = - types.put(name, typeFullName) + def addTypeToScope(name: String, typeFullName: String): Unit = + types.put(name, typeFullName) - def lookupType(name: String, includeWildcards: Boolean): Option[String] = - types.get(name) match - case None if includeWildcards => getNameWithWildcardPrefix(name) - case result => result + def lookupType(name: String, includeWildcards: Boolean): Option[String] = + types.get(name) match + case None if includeWildcards => getNameWithWildcardPrefix(name) + case result => result - def getNameWithWildcardPrefix(name: String): Option[String] = - wildcardImports match - case SingleWildcard(prefix) => Some(s"$prefix.$name") + def getNameWithWildcardPrefix(name: String): Option[String] = + wildcardImports match + case SingleWildcard(prefix) => Some(s"$prefix.$name") - case _ => None + case _ => None - def addWildcardImport(prefix: String): Unit = - wildcardImports match - case NoWildcard => wildcardImports = SingleWildcard(prefix) + def addWildcardImport(prefix: String): Unit = + wildcardImports match + case NoWildcard => wildcardImports = SingleWildcard(prefix) - case SingleWildcard(_) => wildcardImports = MultipleWildcards + case SingleWildcard(_) => wildcardImports = MultipleWildcards - case MultipleWildcards => // Already MultipleWildcards, so change nothing - // TODO: Refactor and remove this - def getVariables(): List[ScopeVariable] = variables.values.toList + case MultipleWildcards => // Already MultipleWildcards, so change nothing + // TODO: Refactor and remove this + def getVariables(): List[ScopeVariable] = variables.values.toList end JavaScopeElement private object JavaScopeElement: - sealed trait WildcardImports - case object NoWildcard extends WildcardImports - case class SingleWildcard(prefix: String) extends WildcardImports - case object MultipleWildcards extends WildcardImports + sealed trait WildcardImports + case object NoWildcard extends WildcardImports + case class SingleWildcard(prefix: String) extends WildcardImports + case object MultipleWildcards extends WildcardImports class NamespaceScope(val namespace: NewNamespaceBlock) extends JavaScopeElement with TypeDeclContainer @@ -66,18 +66,18 @@ class BlockScope extends JavaScopeElement class MethodScope(val method: NewMethod, val returnType: ExpectedType) extends JavaScopeElement class TypeDeclScope(val typeDecl: NewTypeDecl) extends JavaScopeElement with TypeDeclContainer: - private val memberInitializers = mutable.ListBuffer[Ast]() + private val memberInitializers = mutable.ListBuffer[Ast]() - // TODO: Refactor and remove this. - def addMemberInitializers(initializers: Seq[Ast]): Unit = - memberInitializers.appendAll(initializers) + // TODO: Refactor and remove this. + def addMemberInitializers(initializers: Seq[Ast]): Unit = + memberInitializers.appendAll(initializers) - // TODO: Refactor and remove this. - def getMemberInitializers(): List[Ast] = - memberInitializers.toList.flatMap { ast => - ast.root match - case Some(root: AstNodeNew) => - Some(ast.subTreeCopy(root)) + // TODO: Refactor and remove this. + def getMemberInitializers(): List[Ast] = + memberInitializers.toList.flatMap { ast => + ast.root match + case Some(root: AstNodeNew) => + Some(ast.subTreeCopy(root)) - case _ => None - } + case _ => None + } diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/scope/Scope.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/scope/Scope.scala index 3ea7aa2d..10cc5810 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/scope/Scope.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/scope/Scope.scala @@ -32,213 +32,213 @@ case class NodeTypeInfo( isStatic: Boolean = false ) class Scope: - private var scopeStack: List[JavaScopeElement] = Nil + private var scopeStack: List[JavaScopeElement] = Nil - def pushBlockScope(): Unit = - scopeStack = new BlockScope() :: scopeStack + def pushBlockScope(): Unit = + scopeStack = new BlockScope() :: scopeStack - def pushMethodScope(method: NewMethod, returnType: ExpectedType): Unit = - scopeStack = new MethodScope(method, returnType) :: scopeStack + def pushMethodScope(method: NewMethod, returnType: ExpectedType): Unit = + scopeStack = new MethodScope(method, returnType) :: scopeStack - def pushTypeDeclScope(typeDecl: NewTypeDecl): Unit = - scopeStack = new TypeDeclScope(typeDecl) :: scopeStack + def pushTypeDeclScope(typeDecl: NewTypeDecl): Unit = + scopeStack = new TypeDeclScope(typeDecl) :: scopeStack - def pushNamespaceScope(namespace: NewNamespaceBlock): Unit = - scopeStack = new NamespaceScope(namespace) :: scopeStack + def pushNamespaceScope(namespace: NewNamespaceBlock): Unit = + scopeStack = new NamespaceScope(namespace) :: scopeStack - def popScope(): JavaScopeElement = - val scope = scopeStack.head - scopeStack = scopeStack.tail - scope - - def addParameter(parameter: NewMethodParameterIn): Unit = - addVariable(ScopeParameter(parameter)) - - def addLocal(local: NewLocal): Unit = - addVariable(ScopeLocal(local)) - - def addMember(member: NewMember, isStatic: Boolean): Unit = - addVariable(ScopeMember(member, isStatic)) - - def addStaticImport(importNode: NewImport): Unit = - addVariable(ScopeStaticImport(importNode)) - - private def addVariable(variable: ScopeVariable): Unit = - scopeStack.head.addVariableToScope(variable) - - def addType(name: String, typeFullName: String): Unit = - scopeStack.head.addTypeToScope(name, typeFullName) - - def addWildcardImport(prefix: String): Unit = - scopeStack.head.addWildcardImport(prefix) - - def lookupVariable(name: String): VariableLookupResult = - scopeStack.takeUntil(_.lookupVariable(name).isDefined) match - case Nil => NotInScope - - case foundSubstack => - // Know the last will contain the lookup variable since that is the condition to terminate - // the takeUntil - val variable = foundSubstack.last.lookupVariable(name).get - lookupResultFromFound(foundSubstack, variable) - - def lookupType(simpleName: String): Option[String] = - lookupType(simpleName, includeWildcards = true) - - private def lookupType(simpleName: String, includeWildcards: Boolean): Option[String] = - scopeStack.iterator - .map(_.lookupType(simpleName, includeWildcards)) - .collectFirst { case Some(typeFullName) => typeFullName } - - def lookupVariableOrType(name: String): Option[String] = - lookupVariable(name).typeFullName.orElse(lookupType(name, includeWildcards = false)) - - private def lookupResultFromFound( - foundSubstack: List[JavaScopeElement], - variable: ScopeVariable - ): VariableLookupResult = - val typeDecls = foundSubstack.collect { case td: TypeDeclScope => td } - - val isCaptureChain = - typeDecls.size > 1 || (typeDecls.nonEmpty && !foundSubstack.last.isInstanceOf[ - TypeDeclScope - ]) - - if isCaptureChain then - CapturedVariable(typeDecls.map(_.typeDecl), variable) - else - SimpleVariable(variable) - - def enclosingTypeDecl: Option[NewTypeDecl] = - scopeStack.collectFirst { case typeDeclScope: TypeDeclScope => - typeDeclScope.typeDecl - } - - def enclosingTypeDeclFullName: Option[String] = enclosingTypeDecl.map(_.fullName) - - def enclosingMethodReturnType: Option[ExpectedType] = - scopeStack.collectFirst { case methodScope: MethodScope => - methodScope.returnType - } - - def addLocalDecl(decl: Ast): Unit = - scopeStack.collectFirst { case typeDeclContainer: TypeDeclContainer => - typeDeclContainer.registerTypeDecl(decl) - } - - // TODO: The below section of todos are all methods that have been added for simple compatibility with the old - // scope. The plan is to refactor the code to handle these directly in the AstCreator to make the code easier - // to reason about, so these should be removed when that happens. - - // TODO: Refactor and remove this - def addMemberInitializers(initializers: Seq[Ast]): Unit = - scopeStack.collectFirst { case typeDeclScope: TypeDeclScope => - typeDeclScope.addMemberInitializers(initializers) - } - - // TODO: Refactor and remove this - def memberInitializers: List[Ast] = - scopeStack - .collectFirst { case typeDeclScope: TypeDeclScope => - typeDeclScope.getMemberInitializers() - } - .getOrElse(Nil) - - // TODO: Refactor and remove this - def localDeclsInScope: List[Ast] = - scopeStack - .collectFirst { case typeDeclContainer: TypeDeclContainer => - typeDeclContainer.registeredTypeDecls - } - .getOrElse(Nil) - - // TODO: Refactor and remove this - def lambdaMethodsInScope: List[Ast] = - scopeStack - .collectFirst { case typeDeclContainer: TypeDeclContainer => - typeDeclContainer.registeredLambdaMethods - } - .getOrElse(Nil) - - // TODO: Refactor and remove this - def addLambdaMethod(method: Ast): Unit = - scopeStack.collectFirst { case typeDeclContainer: TypeDeclContainer => - typeDeclContainer.registerLambdaMethod(method) - } - - // TODO: Refactor and remove this - def capturedVariables: List[ScopeVariable] = - scopeStack - .flatMap(_.getVariables()) - .collect { - case local: ScopeLocal => local - - case parameter: ScopeParameter if parameter.name != NameConstants.This => parameter - } + def popScope(): JavaScopeElement = + val scope = scopeStack.head + scopeStack = scopeStack.tail + scope + + def addParameter(parameter: NewMethodParameterIn): Unit = + addVariable(ScopeParameter(parameter)) + + def addLocal(local: NewLocal): Unit = + addVariable(ScopeLocal(local)) + + def addMember(member: NewMember, isStatic: Boolean): Unit = + addVariable(ScopeMember(member, isStatic)) + + def addStaticImport(importNode: NewImport): Unit = + addVariable(ScopeStaticImport(importNode)) + + private def addVariable(variable: ScopeVariable): Unit = + scopeStack.head.addVariableToScope(variable) + + def addType(name: String, typeFullName: String): Unit = + scopeStack.head.addTypeToScope(name, typeFullName) + + def addWildcardImport(prefix: String): Unit = + scopeStack.head.addWildcardImport(prefix) + + def lookupVariable(name: String): VariableLookupResult = + scopeStack.takeUntil(_.lookupVariable(name).isDefined) match + case Nil => NotInScope + + case foundSubstack => + // Know the last will contain the lookup variable since that is the condition to terminate + // the takeUntil + val variable = foundSubstack.last.lookupVariable(name).get + lookupResultFromFound(foundSubstack, variable) + + def lookupType(simpleName: String): Option[String] = + lookupType(simpleName, includeWildcards = true) + + private def lookupType(simpleName: String, includeWildcards: Boolean): Option[String] = + scopeStack.iterator + .map(_.lookupType(simpleName, includeWildcards)) + .collectFirst { case Some(typeFullName) => typeFullName } + + def lookupVariableOrType(name: String): Option[String] = + lookupVariable(name).typeFullName.orElse(lookupType(name, includeWildcards = false)) + + private def lookupResultFromFound( + foundSubstack: List[JavaScopeElement], + variable: ScopeVariable + ): VariableLookupResult = + val typeDecls = foundSubstack.collect { case td: TypeDeclScope => td } + + val isCaptureChain = + typeDecls.size > 1 || (typeDecls.nonEmpty && !foundSubstack.last.isInstanceOf[ + TypeDeclScope + ]) + + if isCaptureChain then + CapturedVariable(typeDecls.map(_.typeDecl), variable) + else + SimpleVariable(variable) + + def enclosingTypeDecl: Option[NewTypeDecl] = + scopeStack.collectFirst { case typeDeclScope: TypeDeclScope => + typeDeclScope.typeDecl + } + + def enclosingTypeDeclFullName: Option[String] = enclosingTypeDecl.map(_.fullName) + + def enclosingMethodReturnType: Option[ExpectedType] = + scopeStack.collectFirst { case methodScope: MethodScope => + methodScope.returnType + } + + def addLocalDecl(decl: Ast): Unit = + scopeStack.collectFirst { case typeDeclContainer: TypeDeclContainer => + typeDeclContainer.registerTypeDecl(decl) + } + + // TODO: The below section of todos are all methods that have been added for simple compatibility with the old + // scope. The plan is to refactor the code to handle these directly in the AstCreator to make the code easier + // to reason about, so these should be removed when that happens. + + // TODO: Refactor and remove this + def addMemberInitializers(initializers: Seq[Ast]): Unit = + scopeStack.collectFirst { case typeDeclScope: TypeDeclScope => + typeDeclScope.addMemberInitializers(initializers) + } + + // TODO: Refactor and remove this + def memberInitializers: List[Ast] = + scopeStack + .collectFirst { case typeDeclScope: TypeDeclScope => + typeDeclScope.getMemberInitializers() + } + .getOrElse(Nil) + + // TODO: Refactor and remove this + def localDeclsInScope: List[Ast] = + scopeStack + .collectFirst { case typeDeclContainer: TypeDeclContainer => + typeDeclContainer.registeredTypeDecls + } + .getOrElse(Nil) + + // TODO: Refactor and remove this + def lambdaMethodsInScope: List[Ast] = + scopeStack + .collectFirst { case typeDeclContainer: TypeDeclContainer => + typeDeclContainer.registeredLambdaMethods + } + .getOrElse(Nil) + + // TODO: Refactor and remove this + def addLambdaMethod(method: Ast): Unit = + scopeStack.collectFirst { case typeDeclContainer: TypeDeclContainer => + typeDeclContainer.registerLambdaMethod(method) + } + + // TODO: Refactor and remove this + def capturedVariables: List[ScopeVariable] = + scopeStack + .flatMap(_.getVariables()) + .collect { + case local: ScopeLocal => local + + case parameter: ScopeParameter if parameter.name != NameConstants.This => parameter + } end Scope object Scope: - type NewScopeNode = NewBlock | NewMethod | NewTypeDecl | NewNamespaceBlock - type NewVariableNode = NewLocal | NewMethodParameterIn | NewMember | NewImport - - sealed trait ScopeVariable: - def node: NewVariableNode - def typeFullName: String - def name: String - final case class ScopeLocal(override val node: NewLocal) extends ScopeVariable: - val typeFullName: String = node.typeFullName - val name = node.name - final case class ScopeParameter(override val node: NewMethodParameterIn) extends ScopeVariable: - val typeFullName: String = node.typeFullName - val name = node.name - final case class ScopeMember(override val node: NewMember, isStatic: Boolean) - extends ScopeVariable: - val typeFullName: String = node.typeFullName - val name = node.name - final case class ScopeStaticImport(override val node: NewImport) extends ScopeVariable: - val typeFullName: String = node.importedEntity.get - val name = node.importedAs.get - - sealed trait VariableLookupResult: - def typeFullName: Option[String] = None - def variableNode: Option[NewVariableNode] = None - - // TODO: Added for convenience, but when proper capture logic is implemented the found - // variable cases will have to be handled separately which would render this unnecessary. - def getVariable(): Option[ScopeVariable] = None - - // TODO: Refactor and remove this - def asNodeInfoOption: Option[NodeTypeInfo] = None - case object NotInScope extends VariableLookupResult - sealed trait FoundVariable(variable: ScopeVariable) extends VariableLookupResult: - override val typeFullName: Option[String] = Some(variable.typeFullName) - override val variableNode: Option[NewVariableNode] = Some(variable.node) - override def getVariable(): Option[ScopeVariable] = Some(variable) - - override def asNodeInfoOption: Option[NodeTypeInfo] = - val nodeTypeInfo = variable match - case ScopeMember(memberNode, isStatic) => - NodeTypeInfo( - memberNode, - memberNode.name, - Some(memberNode.typeFullName), - true, - isStatic - ) - - case variable => NodeTypeInfo( - variable.node, - variable.name, - Some(variable.typeFullName), - false, - false - ) - - Some(nodeTypeInfo) - end asNodeInfoOption - end FoundVariable - final case class SimpleVariable(variable: ScopeVariable) extends FoundVariable(variable) - - final case class CapturedVariable(typeDeclChain: List[NewTypeDecl], variable: ScopeVariable) - extends FoundVariable(variable) + type NewScopeNode = NewBlock | NewMethod | NewTypeDecl | NewNamespaceBlock + type NewVariableNode = NewLocal | NewMethodParameterIn | NewMember | NewImport + + sealed trait ScopeVariable: + def node: NewVariableNode + def typeFullName: String + def name: String + final case class ScopeLocal(override val node: NewLocal) extends ScopeVariable: + val typeFullName: String = node.typeFullName + val name = node.name + final case class ScopeParameter(override val node: NewMethodParameterIn) extends ScopeVariable: + val typeFullName: String = node.typeFullName + val name = node.name + final case class ScopeMember(override val node: NewMember, isStatic: Boolean) + extends ScopeVariable: + val typeFullName: String = node.typeFullName + val name = node.name + final case class ScopeStaticImport(override val node: NewImport) extends ScopeVariable: + val typeFullName: String = node.importedEntity.get + val name = node.importedAs.get + + sealed trait VariableLookupResult: + def typeFullName: Option[String] = None + def variableNode: Option[NewVariableNode] = None + + // TODO: Added for convenience, but when proper capture logic is implemented the found + // variable cases will have to be handled separately which would render this unnecessary. + def getVariable(): Option[ScopeVariable] = None + + // TODO: Refactor and remove this + def asNodeInfoOption: Option[NodeTypeInfo] = None + case object NotInScope extends VariableLookupResult + sealed trait FoundVariable(variable: ScopeVariable) extends VariableLookupResult: + override val typeFullName: Option[String] = Some(variable.typeFullName) + override val variableNode: Option[NewVariableNode] = Some(variable.node) + override def getVariable(): Option[ScopeVariable] = Some(variable) + + override def asNodeInfoOption: Option[NodeTypeInfo] = + val nodeTypeInfo = variable match + case ScopeMember(memberNode, isStatic) => + NodeTypeInfo( + memberNode, + memberNode.name, + Some(memberNode.typeFullName), + true, + isStatic + ) + + case variable => NodeTypeInfo( + variable.node, + variable.name, + Some(variable.typeFullName), + false, + false + ) + + Some(nodeTypeInfo) + end asNodeInfoOption + end FoundVariable + final case class SimpleVariable(variable: ScopeVariable) extends FoundVariable(variable) + + final case class CapturedVariable(typeDeclChain: List[NewTypeDecl], variable: ScopeVariable) + extends FoundVariable(variable) end Scope diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/scope/TypeDeclContainer.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/scope/TypeDeclContainer.scala index 23a3341c..68fba55c 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/scope/TypeDeclContainer.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/scope/TypeDeclContainer.scala @@ -5,16 +5,16 @@ import io.appthreat.x2cpg.Ast import scala.collection.mutable trait TypeDeclContainer: - private val typeDeclsToAdd = mutable.ListBuffer[Ast]() - private val lambdaMethods = mutable.ListBuffer[Ast]() + private val typeDeclsToAdd = mutable.ListBuffer[Ast]() + private val lambdaMethods = mutable.ListBuffer[Ast]() - def registerTypeDecl(typeDecl: Ast) = - typeDeclsToAdd.append(typeDecl) + def registerTypeDecl(typeDecl: Ast) = + typeDeclsToAdd.append(typeDecl) - def registeredTypeDecls: List[Ast] = typeDeclsToAdd.toList + def registeredTypeDecls: List[Ast] = typeDeclsToAdd.toList - // TODO: Refactor and remove this - def registerLambdaMethod(lambdaMethod: Ast) = - lambdaMethods.append(lambdaMethod) + // TODO: Refactor and remove this + def registerLambdaMethod(lambdaMethod: Ast) = + lambdaMethods.append(lambdaMethod) - def registeredLambdaMethods: List[Ast] = lambdaMethods.toList + def registeredLambdaMethods: List[Ast] = lambdaMethods.toList diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/EagerSourceTypeSolver.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/EagerSourceTypeSolver.scala index 0a41602e..ca04f507 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/EagerSourceTypeSolver.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/EagerSourceTypeSolver.scala @@ -20,60 +20,60 @@ class EagerSourceTypeSolver( symbolSolver: JavaSymbolSolver ) extends TypeSolver: - private val logger = LoggerFactory.getLogger(this.getClass) - private var parent: TypeSolver = scala.compiletime.uninitialized + private val logger = LoggerFactory.getLogger(this.getClass) + private var parent: TypeSolver = scala.compiletime.uninitialized - private val foundTypes: Map[String, SymbolReference[ResolvedReferenceTypeDeclaration]] = - filenames - .flatMap(sourceParser.parseTypesFile) - .flatMap { cu => - symbolSolver.inject(cu) - cu.findAll(classOf[TypeDeclaration[?]]) - .asScala - .map { typeDeclaration => - val name = typeDeclaration.getFullyQualifiedName.toScala match - case Some(fullyQualifiedName) => fullyQualifiedName - case None => - val name = typeDeclaration.getNameAsString - logger.debug( - s"Could not find fully qualified name for typeDecl $name" - ) - name - TypeSizeReducer.simplifyType(typeDeclaration) - val resolvedSymbol = Try( - SymbolReference.solved( - JavaParserFacade.get(combinedTypeSolver).getTypeDeclaration( - typeDeclaration - ) - ): SymbolReference[ResolvedReferenceTypeDeclaration] - ).getOrElse(SymbolReference.unsolved()) - name -> resolvedSymbol - } - .toList - } - .toMap + private val foundTypes: Map[String, SymbolReference[ResolvedReferenceTypeDeclaration]] = + filenames + .flatMap(sourceParser.parseTypesFile) + .flatMap { cu => + symbolSolver.inject(cu) + cu.findAll(classOf[TypeDeclaration[?]]) + .asScala + .map { typeDeclaration => + val name = typeDeclaration.getFullyQualifiedName.toScala match + case Some(fullyQualifiedName) => fullyQualifiedName + case None => + val name = typeDeclaration.getNameAsString + logger.debug( + s"Could not find fully qualified name for typeDecl $name" + ) + name + TypeSizeReducer.simplifyType(typeDeclaration) + val resolvedSymbol = Try( + SymbolReference.solved( + JavaParserFacade.get(combinedTypeSolver).getTypeDeclaration( + typeDeclaration + ) + ): SymbolReference[ResolvedReferenceTypeDeclaration] + ).getOrElse(SymbolReference.unsolved()) + name -> resolvedSymbol + } + .toList + } + .toMap - override def getParent: TypeSolver = parent + override def getParent: TypeSolver = parent - override def setParent(parent: TypeSolver): Unit = - if parent == null then - logger.debug(s"Cannot set parent of type solver to null. setParent will be ignored.") - else if this.parent != null then - logger.debug(s"Attempting to re-set type solver parent. setParent will be ignored.") - else if parent == this then - logger.debug(s"Parent of TypeSolver cannot be itself. setParent will be ignored.") - else - this.parent = parent + override def setParent(parent: TypeSolver): Unit = + if parent == null then + logger.debug(s"Cannot set parent of type solver to null. setParent will be ignored.") + else if this.parent != null then + logger.debug(s"Attempting to re-set type solver parent. setParent will be ignored.") + else if parent == this then + logger.debug(s"Parent of TypeSolver cannot be itself. setParent will be ignored.") + else + this.parent = parent - override def tryToSolveType(name: String): SymbolReference[ResolvedReferenceTypeDeclaration] = - foundTypes.getOrElse(name, SymbolReference.unsolved()) + override def tryToSolveType(name: String): SymbolReference[ResolvedReferenceTypeDeclaration] = + foundTypes.getOrElse(name, SymbolReference.unsolved()) end EagerSourceTypeSolver object EagerSourceTypeSolver: - def apply( - filenames: Array[String], - sourceParser: SourceParser, - combinedTypeSolver: SimpleCombinedTypeSolver, - symbolSolver: JavaSymbolSolver - ): EagerSourceTypeSolver = - new EagerSourceTypeSolver(filenames, sourceParser, combinedTypeSolver, symbolSolver) + def apply( + filenames: Array[String], + sourceParser: SourceParser, + combinedTypeSolver: SimpleCombinedTypeSolver, + symbolSolver: JavaSymbolSolver + ): EagerSourceTypeSolver = + new EagerSourceTypeSolver(filenames, sourceParser, combinedTypeSolver, symbolSolver) diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/JmodClassPath.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/JmodClassPath.scala index fc790c06..5b54c042 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/JmodClassPath.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/JmodClassPath.scala @@ -11,35 +11,35 @@ import java.net.URL import java.util.jar.{JarEntry, JarFile} class JmodClassPath(jmodPath: String) extends ClassPath: - private val jarfile = new JarFile(jmodPath) - private val jarfileURL = File(jmodPath).url.toString - private val entries = getEntriesMap(jarfile) - - private def entryToClassName(entry: JarEntry): String = - entry.getName.stripPrefix(JmodClassesPrefix).stripSuffix(".class").replace('/', '.') - - private def getEntriesMap(jarfile: JarFile): Map[String, JarEntry] = - jarfile - .entries() - .asScala - .filter(_.getName.startsWith(JmodClassesPrefix)) - .filter(_.getName.endsWith(".class")) - .map { entry => entryToClassName(entry) -> entry } - .toMap - - override def find(classname: String): URL = - val jarname = classname.replace('.', '/') + ".class" - - if entries.contains(classname) then - Try(new URL(s"jmod:${jarfileURL}!/${jarname}")).getOrElse(null) - else null - - override def openClassfile(classname: String): InputStream = - entries.get(classname) match - case None => null - - case Some(entry) => jarfile.getInputStream(entry) + private val jarfile = new JarFile(jmodPath) + private val jarfileURL = File(jmodPath).url.toString + private val entries = getEntriesMap(jarfile) + + private def entryToClassName(entry: JarEntry): String = + entry.getName.stripPrefix(JmodClassesPrefix).stripSuffix(".class").replace('/', '.') + + private def getEntriesMap(jarfile: JarFile): Map[String, JarEntry] = + jarfile + .entries() + .asScala + .filter(_.getName.startsWith(JmodClassesPrefix)) + .filter(_.getName.endsWith(".class")) + .map { entry => entryToClassName(entry) -> entry } + .toMap + + override def find(classname: String): URL = + val jarname = classname.replace('.', '/') + ".class" + + if entries.contains(classname) then + Try(new URL(s"jmod:${jarfileURL}!/${jarname}")).getOrElse(null) + else null + + override def openClassfile(classname: String): InputStream = + entries.get(classname) match + case None => null + + case Some(entry) => jarfile.getInputStream(entry) end JmodClassPath object JmodClassPath: - val JmodClassesPrefix: String = "classes/" + val JmodClassesPrefix: String = "classes/" diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/NonCachingClassPool.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/NonCachingClassPool.scala index e44ed1d4..1e0e7efd 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/NonCachingClassPool.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/NonCachingClassPool.scala @@ -10,5 +10,5 @@ import scala.annotation.nowarn * the search list. */ class NonCachingClassPool extends ClassPool(false): - @nowarn override def cacheCtClass(className: String, ctClass: CtClass, dynamic: Boolean): Unit = - () + @nowarn override def cacheCtClass(className: String, ctClass: CtClass, dynamic: Boolean): Unit = + () diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/SimpleCombinedTypeSolver.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/SimpleCombinedTypeSolver.scala index 4938ed71..5a26b275 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/SimpleCombinedTypeSolver.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/SimpleCombinedTypeSolver.scala @@ -13,81 +13,81 @@ import scala.jdk.OptionConverters.RichOptional class SimpleCombinedTypeSolver extends TypeSolver: - private val logger = LoggerFactory.getLogger(this.getClass) - private var parent: TypeSolver = scala.compiletime.uninitialized - // Ideally all types would be cached in the SimpleCombinedTypeSolver to avoid unnecessary unresolved types - // from being cached. The EagerSourceTypeSolver preloads all types, however, so separating caching and - // non-caching solvers avoids caching types twice. - private val cachingTypeSolvers: mutable.ArrayBuffer[TypeSolver] = mutable.ArrayBuffer() - private val nonCachingTypeSolvers: mutable.ArrayBuffer[TypeSolver] = mutable.ArrayBuffer() + private val logger = LoggerFactory.getLogger(this.getClass) + private var parent: TypeSolver = scala.compiletime.uninitialized + // Ideally all types would be cached in the SimpleCombinedTypeSolver to avoid unnecessary unresolved types + // from being cached. The EagerSourceTypeSolver preloads all types, however, so separating caching and + // non-caching solvers avoids caching types twice. + private val cachingTypeSolvers: mutable.ArrayBuffer[TypeSolver] = mutable.ArrayBuffer() + private val nonCachingTypeSolvers: mutable.ArrayBuffer[TypeSolver] = mutable.ArrayBuffer() - private val typeCache = new GuavaCache( - CacheBuilder.newBuilder().build[String, SymbolReference[ResolvedReferenceTypeDeclaration]]() - ) + private val typeCache = new GuavaCache( + CacheBuilder.newBuilder().build[String, SymbolReference[ResolvedReferenceTypeDeclaration]]() + ) - def addCachingTypeSolver(typeSolver: TypeSolver): Unit = - cachingTypeSolvers.append(typeSolver) - typeSolver.setParent(this) + def addCachingTypeSolver(typeSolver: TypeSolver): Unit = + cachingTypeSolvers.append(typeSolver) + typeSolver.setParent(this) - def addNonCachingTypeSolver(typeSolver: TypeSolver): Unit = - nonCachingTypeSolvers.prepend(typeSolver) - typeSolver.setParent(this) + def addNonCachingTypeSolver(typeSolver: TypeSolver): Unit = + nonCachingTypeSolvers.prepend(typeSolver) + typeSolver.setParent(this) - override def tryToSolveType(name: String): SymbolReference[ResolvedReferenceTypeDeclaration] = - typeCache.get(name).toScala match - case Some(result) => result + override def tryToSolveType(name: String): SymbolReference[ResolvedReferenceTypeDeclaration] = + typeCache.get(name).toScala match + case Some(result) => result - case None => - findSolvedTypeWithSolvers(cachingTypeSolvers, name) - .getOrElse { - val result = findSolvedTypeWithSolvers( - nonCachingTypeSolvers, - name - ).getOrElse(SymbolReference.unsolved()) - typeCache.put(name, result) - result - } + case None => + findSolvedTypeWithSolvers(cachingTypeSolvers, name) + .getOrElse { + val result = findSolvedTypeWithSolvers( + nonCachingTypeSolvers, + name + ).getOrElse(SymbolReference.unsolved()) + typeCache.put(name, result) + result + } - private def findSolvedTypeWithSolvers( - typeSolvers: mutable.ArrayBuffer[TypeSolver], - className: String - ): Option[SymbolReference[ResolvedReferenceTypeDeclaration]] = - typeSolvers.iterator - .map { typeSolver => - try - val result = typeSolver.tryToSolveType(className): SymbolReference[ - ResolvedReferenceTypeDeclaration - ] - Option.when(result.isSolved())(result) - catch - case _: UnsolvedSymbolException => None - case _: StackOverflowError => None - case _: IllegalArgumentException => - // RecordDeclarations aren't handled by JavaParser yet - None - case unhandled: Throwable => - None - } - .collectFirst { case Some(symbolReference) => - symbolReference - } + private def findSolvedTypeWithSolvers( + typeSolvers: mutable.ArrayBuffer[TypeSolver], + className: String + ): Option[SymbolReference[ResolvedReferenceTypeDeclaration]] = + typeSolvers.iterator + .map { typeSolver => + try + val result = typeSolver.tryToSolveType(className): SymbolReference[ + ResolvedReferenceTypeDeclaration + ] + Option.when(result.isSolved())(result) + catch + case _: UnsolvedSymbolException => None + case _: StackOverflowError => None + case _: IllegalArgumentException => + // RecordDeclarations aren't handled by JavaParser yet + None + case unhandled: Throwable => + None + } + .collectFirst { case Some(symbolReference) => + symbolReference + } - override def solveType(name: String): ResolvedReferenceTypeDeclaration = - val result = tryToSolveType(name) - if result.isSolved then - result.getCorrespondingDeclaration - else - throw new UnsolvedSymbolException(name) + override def solveType(name: String): ResolvedReferenceTypeDeclaration = + val result = tryToSolveType(name) + if result.isSolved then + result.getCorrespondingDeclaration + else + throw new UnsolvedSymbolException(name) - override def getParent: TypeSolver = parent + override def getParent: TypeSolver = parent - override def setParent(parent: TypeSolver): Unit = - if parent == null then - logger.debug(s"Cannot set parent of type solver to null. setParent will be ignored.") - else if this.parent != null then - logger.debug(s"Attempting to re-set type solver parent. setParent will be ignored.") - else if parent == this then - logger.debug(s"Parent of TypeSolver cannot be itself. setParent will be ignored.") - else - this.parent = parent + override def setParent(parent: TypeSolver): Unit = + if parent == null then + logger.debug(s"Cannot set parent of type solver to null. setParent will be ignored.") + else if this.parent != null then + logger.debug(s"Attempting to re-set type solver parent. setParent will be ignored.") + else if parent == this then + logger.debug(s"Parent of TypeSolver cannot be itself. setParent will be ignored.") + else + this.parent = parent end SimpleCombinedTypeSolver diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/TypeInfoCalculator.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/TypeInfoCalculator.scala index 91ce7b4e..dda1ef99 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/TypeInfoCalculator.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/TypeInfoCalculator.scala @@ -20,253 +20,253 @@ import scala.jdk.OptionConverters.RichOptional import scala.util.Try class TypeInfoCalculator(global: Global, symbolResolver: SymbolResolver): - private val logger = LoggerFactory.getLogger(this.getClass) - private val emptyTypeParamValues = ResolvedTypeParametersMap.empty() + private val logger = LoggerFactory.getLogger(this.getClass) + private val emptyTypeParamValues = ResolvedTypeParametersMap.empty() - def name(typ: ResolvedType): Option[String] = - nameOrFullName(typ, emptyTypeParamValues, fullyQualified = false) + def name(typ: ResolvedType): Option[String] = + nameOrFullName(typ, emptyTypeParamValues, fullyQualified = false) - def name(typ: ResolvedType, typeParamValues: ResolvedTypeParametersMap): Option[String] = - nameOrFullName(typ, typeParamValues, fullyQualified = false) + def name(typ: ResolvedType, typeParamValues: ResolvedTypeParametersMap): Option[String] = + nameOrFullName(typ, typeParamValues, fullyQualified = false) - def fullName(typ: ResolvedType): Option[String] = - nameOrFullName(typ, emptyTypeParamValues, fullyQualified = true).map(registerType) + def fullName(typ: ResolvedType): Option[String] = + nameOrFullName(typ, emptyTypeParamValues, fullyQualified = true).map(registerType) - def fullName(typ: ResolvedType, typeParamValues: ResolvedTypeParametersMap): Option[String] = - nameOrFullName(typ, typeParamValues, fullyQualified = true).map(registerType) + def fullName(typ: ResolvedType, typeParamValues: ResolvedTypeParametersMap): Option[String] = + nameOrFullName(typ, typeParamValues, fullyQualified = true).map(registerType) - private def typesSubstituted( - trySubstitutedType: Try[ResolvedType], - typeParamDecl: ResolvedTypeParameterDeclaration - ): Boolean = - trySubstitutedType - .map { substitutedType => - // substitutedType.isTypeVariable can crash with an UnsolvedSymbolException if it is an instance of LazyType, - // in which case the type hasn't been successfully substituted. - val substitutionOccurred = - !(substitutedType.isTypeVariable && substitutedType.asTypeParameter() == typeParamDecl) - // There's a potential infinite loop that can occur when a type variable is substituted with a wildcard type - // bounded by that type variable. - val isSimilarWildcardSubstition = substitutedType match - case wc: ResolvedWildcard => - Try(wc.getBoundedType.asTypeParameter()).toOption.contains(typeParamDecl) - case _ => false + private def typesSubstituted( + trySubstitutedType: Try[ResolvedType], + typeParamDecl: ResolvedTypeParameterDeclaration + ): Boolean = + trySubstitutedType + .map { substitutedType => + // substitutedType.isTypeVariable can crash with an UnsolvedSymbolException if it is an instance of LazyType, + // in which case the type hasn't been successfully substituted. + val substitutionOccurred = + !(substitutedType.isTypeVariable && substitutedType.asTypeParameter() == typeParamDecl) + // There's a potential infinite loop that can occur when a type variable is substituted with a wildcard type + // bounded by that type variable. + val isSimilarWildcardSubstition = substitutedType match + case wc: ResolvedWildcard => + Try(wc.getBoundedType.asTypeParameter()).toOption.contains(typeParamDecl) + case _ => false - substitutionOccurred && !isSimilarWildcardSubstition - } - .getOrElse(false) + substitutionOccurred && !isSimilarWildcardSubstition + } + .getOrElse(false) - private def nameOrFullName( - typ: ResolvedType, - typeParamValues: ResolvedTypeParametersMap, - fullyQualified: Boolean - ): Option[String] = - typ match - case refType: ResolvedReferenceType => - nameOrFullName(refType.getTypeDeclaration.get, fullyQualified) - case lazyType: LazyType => - lazyType match - case _ if lazyType.isReferenceType => - nameOrFullName(lazyType.asReferenceType(), typeParamValues, fullyQualified) - case _ if lazyType.isTypeVariable => - nameOrFullName(lazyType.asTypeVariable(), typeParamValues, fullyQualified) - case _ if lazyType.isArray => - nameOrFullName(lazyType.asArrayType(), typeParamValues, fullyQualified) - case _ if lazyType.isPrimitive => - nameOrFullName(lazyType.asPrimitive(), typeParamValues, fullyQualified) - case _ if lazyType.isWildcard => - nameOrFullName(lazyType.asWildcard(), typeParamValues, fullyQualified) - case tpe @ (_: ResolvedVoidType | _: ResolvedPrimitiveType) => - Some(tpe.describe()) - case arrayType: ResolvedArrayType => - nameOrFullName(arrayType.getComponentType, typeParamValues, fullyQualified).map( - _ + "[]" - ) - case nullType: NullType => - Some(nullType.describe()) - case typeVariable: ResolvedTypeVariable => - val typeParamDecl = typeVariable.asTypeParameter() - val substitutedType = Try(typeParamValues.getValue(typeParamDecl)) + private def nameOrFullName( + typ: ResolvedType, + typeParamValues: ResolvedTypeParametersMap, + fullyQualified: Boolean + ): Option[String] = + typ match + case refType: ResolvedReferenceType => + nameOrFullName(refType.getTypeDeclaration.get, fullyQualified) + case lazyType: LazyType => + lazyType match + case _ if lazyType.isReferenceType => + nameOrFullName(lazyType.asReferenceType(), typeParamValues, fullyQualified) + case _ if lazyType.isTypeVariable => + nameOrFullName(lazyType.asTypeVariable(), typeParamValues, fullyQualified) + case _ if lazyType.isArray => + nameOrFullName(lazyType.asArrayType(), typeParamValues, fullyQualified) + case _ if lazyType.isPrimitive => + nameOrFullName(lazyType.asPrimitive(), typeParamValues, fullyQualified) + case _ if lazyType.isWildcard => + nameOrFullName(lazyType.asWildcard(), typeParamValues, fullyQualified) + case tpe @ (_: ResolvedVoidType | _: ResolvedPrimitiveType) => + Some(tpe.describe()) + case arrayType: ResolvedArrayType => + nameOrFullName(arrayType.getComponentType, typeParamValues, fullyQualified).map( + _ + "[]" + ) + case nullType: NullType => + Some(nullType.describe()) + case typeVariable: ResolvedTypeVariable => + val typeParamDecl = typeVariable.asTypeParameter() + val substitutedType = Try(typeParamValues.getValue(typeParamDecl)) - if typesSubstituted(substitutedType, typeParamDecl) then - nameOrFullName(substitutedType.get, typeParamValues, fullyQualified) - else - val extendsBoundOption = - Try(typeParamDecl.getBounds.asScala.find(_.isExtends)).toOption.flatten - extendsBoundOption - .flatMap(bound => - nameOrFullName(bound.getType, typeParamValues, fullyQualified) - ) - .orElse(objectType(fullyQualified)) - case lambdaConstraintType: ResolvedLambdaConstraintType => - nameOrFullName(lambdaConstraintType.getBound, typeParamValues, fullyQualified) - case wildcardType: ResolvedWildcard => - if wildcardType.isBounded then - nameOrFullName(wildcardType.getBoundedType, typeParamValues, fullyQualified) - else - objectType(fullyQualified) - case unionType: ResolvedUnionType => - Try(unionType.getElements.asScala.find(_.isReferenceType)).toOption.flatten - .flatMap(nameOrFullName(_, typeParamValues, fullyQualified)) - .orElse(objectType(fullyQualified)) - case intersectionType: ResolvedIntersectionType => - Try(intersectionType.getElements.asScala.find(_.isReferenceType)).toOption.flatten - .flatMap(nameOrFullName(_, typeParamValues, fullyQualified)) - .orElse(objectType(fullyQualified)) - case _: InferenceVariableType => - // From the JavaParser docs, the InferenceVariableType is: An element using during type inference. - // At this point JavaParser has failed to resolve the type. - None + if typesSubstituted(substitutedType, typeParamDecl) then + nameOrFullName(substitutedType.get, typeParamValues, fullyQualified) + else + val extendsBoundOption = + Try(typeParamDecl.getBounds.asScala.find(_.isExtends)).toOption.flatten + extendsBoundOption + .flatMap(bound => + nameOrFullName(bound.getType, typeParamValues, fullyQualified) + ) + .orElse(objectType(fullyQualified)) + case lambdaConstraintType: ResolvedLambdaConstraintType => + nameOrFullName(lambdaConstraintType.getBound, typeParamValues, fullyQualified) + case wildcardType: ResolvedWildcard => + if wildcardType.isBounded then + nameOrFullName(wildcardType.getBoundedType, typeParamValues, fullyQualified) + else + objectType(fullyQualified) + case unionType: ResolvedUnionType => + Try(unionType.getElements.asScala.find(_.isReferenceType)).toOption.flatten + .flatMap(nameOrFullName(_, typeParamValues, fullyQualified)) + .orElse(objectType(fullyQualified)) + case intersectionType: ResolvedIntersectionType => + Try(intersectionType.getElements.asScala.find(_.isReferenceType)).toOption.flatten + .flatMap(nameOrFullName(_, typeParamValues, fullyQualified)) + .orElse(objectType(fullyQualified)) + case _: InferenceVariableType => + // From the JavaParser docs, the InferenceVariableType is: An element using during type inference. + // At this point JavaParser has failed to resolve the type. + None - private def objectType(fullyQualified: Boolean): Option[String] = - // Return an option type for - if fullyQualified then - Some(TypeConstants.Object) - else - Some(TypeNameConstants.Object) + private def objectType(fullyQualified: Boolean): Option[String] = + // Return an option type for + if fullyQualified then + Some(TypeConstants.Object) + else + Some(TypeNameConstants.Object) - def name(typ: Type): Option[String] = - nameOrFullName(typ, fullyQualified = false) + def name(typ: Type): Option[String] = + nameOrFullName(typ, fullyQualified = false) - def fullName(typ: Type): Option[String] = - nameOrFullName(typ, fullyQualified = true).map(registerType) + def fullName(typ: Type): Option[String] = + nameOrFullName(typ, fullyQualified = true).map(registerType) - private def nameOrFullName(typ: Type, fullyQualified: Boolean): Option[String] = - typ match - case primitiveType: PrimitiveType => - Some(primitiveType.toString) - case _ => - // We are using symbolResolver.toResolvedType() instead of typ.resolve() because - // the resolve() is just a wrapper for a call to symbolResolver.toResolvedType() - // with a specific class given as argument to which the result is casted to. - // It appears to be that ClassOrInterfaceType.resolve() is using a too restrictive - // bound (ResolvedReferenceType.class) which invalidates an otherwise successful - // resolve. Since we anyway dont care about the type cast, we directly access the - // symbolResolver and specifiy the most generic type ResolvedType. - Try(symbolResolver.toResolvedType(typ, classOf[ResolvedType])).toOption - .flatMap(resolvedType => - nameOrFullName(resolvedType, emptyTypeParamValues, fullyQualified) - ) + private def nameOrFullName(typ: Type, fullyQualified: Boolean): Option[String] = + typ match + case primitiveType: PrimitiveType => + Some(primitiveType.toString) + case _ => + // We are using symbolResolver.toResolvedType() instead of typ.resolve() because + // the resolve() is just a wrapper for a call to symbolResolver.toResolvedType() + // with a specific class given as argument to which the result is casted to. + // It appears to be that ClassOrInterfaceType.resolve() is using a too restrictive + // bound (ResolvedReferenceType.class) which invalidates an otherwise successful + // resolve. Since we anyway dont care about the type cast, we directly access the + // symbolResolver and specifiy the most generic type ResolvedType. + Try(symbolResolver.toResolvedType(typ, classOf[ResolvedType])).toOption + .flatMap(resolvedType => + nameOrFullName(resolvedType, emptyTypeParamValues, fullyQualified) + ) - def name(decl: ResolvedDeclaration): Option[String] = - nameOrFullName(decl, fullyQualified = false) + def name(decl: ResolvedDeclaration): Option[String] = + nameOrFullName(decl, fullyQualified = false) - def fullName(decl: ResolvedDeclaration): Option[String] = - nameOrFullName(decl, fullyQualified = true).map(registerType) + def fullName(decl: ResolvedDeclaration): Option[String] = + nameOrFullName(decl, fullyQualified = true).map(registerType) - private def nameOrFullName(decl: ResolvedDeclaration, fullyQualified: Boolean): Option[String] = - decl match - case typeDecl: ResolvedTypeDeclaration => - nameOrFullName(typeDecl, fullyQualified) + private def nameOrFullName(decl: ResolvedDeclaration, fullyQualified: Boolean): Option[String] = + decl match + case typeDecl: ResolvedTypeDeclaration => + nameOrFullName(typeDecl, fullyQualified) - private def nameOrFullName( - typeDecl: ResolvedTypeDeclaration, - fullyQualified: Boolean - ): Option[String] = - typeDecl match - case typeParamDecl: ResolvedTypeParameterDeclaration => - if fullyQualified then - val containFullName = - nameOrFullName( - typeParamDecl.getContainer.asInstanceOf[ResolvedDeclaration], - fullyQualified = true - ) - containFullName.map(_ + "." + typeParamDecl.getName) - else - Some(typeParamDecl.getName) - case _ => - val typeName = Option(typeDecl.getName).getOrElse( - throw new RuntimeException("TODO Investigate") - ) + private def nameOrFullName( + typeDecl: ResolvedTypeDeclaration, + fullyQualified: Boolean + ): Option[String] = + typeDecl match + case typeParamDecl: ResolvedTypeParameterDeclaration => + if fullyQualified then + val containFullName = + nameOrFullName( + typeParamDecl.getContainer.asInstanceOf[ResolvedDeclaration], + fullyQualified = true + ) + containFullName.map(_ + "." + typeParamDecl.getName) + else + Some(typeParamDecl.getName) + case _ => + val typeName = Option(typeDecl.getName).getOrElse( + throw new RuntimeException("TODO Investigate") + ) - // TODO Sadly we need to use a try here in order to catch the exception emitted by - // the javaparser library instead of just returning an empty option. - // In almost all cases we get here the exception is thrown. Check impact on performance - // and hopefully find a better solution if necessary. - val isInnerTypeDecl = Try(typeDecl.containerType().isPresent).getOrElse(false) - if isInnerTypeDecl then - nameOrFullName(typeDecl.containerType().get, fullyQualified).map( - _ + "$" + typeName - ) - else if fullyQualified then - val packageName = typeDecl.getPackageName + // TODO Sadly we need to use a try here in order to catch the exception emitted by + // the javaparser library instead of just returning an empty option. + // In almost all cases we get here the exception is thrown. Check impact on performance + // and hopefully find a better solution if necessary. + val isInnerTypeDecl = Try(typeDecl.containerType().isPresent).getOrElse(false) + if isInnerTypeDecl then + nameOrFullName(typeDecl.containerType().get, fullyQualified).map( + _ + "$" + typeName + ) + else if fullyQualified then + val packageName = typeDecl.getPackageName - if packageName == null || packageName == "" then - Some(typeName) - else - Some(packageName + "." + typeName) - else - Some(typeName) + if packageName == null || packageName == "" then + Some(typeName) + else + Some(packageName + "." + typeName) + else + Some(typeName) - /** Add `typeName` to a global map and return it. The map is later passed to a pass that creates - * TYPE nodes for each key in the map. Skip the `ANY` type, since this is created by default. - * TODO: I want the type registration not in here but for now it is the easiest. - */ - def registerType(typeName: String): String = - if typeName != "ANY" then - global.usedTypes.putIfAbsent(typeName, true) - typeName + /** Add `typeName` to a global map and return it. The map is later passed to a pass that creates + * TYPE nodes for each key in the map. Skip the `ANY` type, since this is created by default. + * TODO: I want the type registration not in here but for now it is the easiest. + */ + def registerType(typeName: String): String = + if typeName != "ANY" then + global.usedTypes.putIfAbsent(typeName, true) + typeName end TypeInfoCalculator object TypeInfoCalculator: - def isAutocastType(typeName: String): Boolean = - NumericTypes.contains(typeName) + def isAutocastType(typeName: String): Boolean = + NumericTypes.contains(typeName) - object TypeConstants: - val Byte: String = "byte" - val Short: String = "short" - val Int: String = "int" - val Long: String = "long" - val Float: String = "float" - val Double: String = "double" - val Char: String = "char" - val Boolean: String = "boolean" - val Object: String = "java.lang.Object" - val Class: String = "java.lang.Class" - val Iterator: String = "java.util.Iterator" - val Void: String = "void" - val Any: String = "ANY" + object TypeConstants: + val Byte: String = "byte" + val Short: String = "short" + val Int: String = "int" + val Long: String = "long" + val Float: String = "float" + val Double: String = "double" + val Char: String = "char" + val Boolean: String = "boolean" + val Object: String = "java.lang.Object" + val Class: String = "java.lang.Class" + val Iterator: String = "java.util.Iterator" + val Void: String = "void" + val Any: String = "ANY" - object TypeNameConstants: - val Object: String = "Object" + object TypeNameConstants: + val Object: String = "Object" - // The method signatures for all methods implemented by java.lang.Object, as returned by JavaParser. This is used - // to filter out Object methods when determining which functional interface method a lambda implements. See - // https://docs.oracle.com/javase/8/docs/api/java/lang/FunctionalInterface.html for more details. - val ObjectMethodSignatures: Set[String] = Set( - "wait(long, int)", - "equals(java.lang.Object)", - "clone()", - "toString()", - "wait()", - "hashCode()", - "getClass()", - "notify()", - "finalize()", - "wait(long)", - "notifyAll()", - "registerNatives()" - ) + // The method signatures for all methods implemented by java.lang.Object, as returned by JavaParser. This is used + // to filter out Object methods when determining which functional interface method a lambda implements. See + // https://docs.oracle.com/javase/8/docs/api/java/lang/FunctionalInterface.html for more details. + val ObjectMethodSignatures: Set[String] = Set( + "wait(long, int)", + "equals(java.lang.Object)", + "clone()", + "toString()", + "wait()", + "hashCode()", + "getClass()", + "notify()", + "finalize()", + "wait(long)", + "notifyAll()", + "registerNatives()" + ) - val NumericTypes: Set[String] = Set( - "byte", - "short", - "int", - "long", - "float", - "double", - "char", - "boolean", - "java.lang.Byte", - "java.lang.Short", - "java.lang.Integer", - "java.lang.Long", - "java.lang.Float", - "java.lang.Double", - "java.lang.Character", - "java.lang.Boolean" - ) + val NumericTypes: Set[String] = Set( + "byte", + "short", + "int", + "long", + "float", + "double", + "char", + "boolean", + "java.lang.Byte", + "java.lang.Short", + "java.lang.Integer", + "java.lang.Long", + "java.lang.Float", + "java.lang.Double", + "java.lang.Character", + "java.lang.Boolean" + ) - def apply(global: Global, symbolResolver: SymbolResolver): TypeInfoCalculator = - new TypeInfoCalculator(global, symbolResolver) + def apply(global: Global, symbolResolver: SymbolResolver): TypeInfoCalculator = + new TypeInfoCalculator(global, symbolResolver) end TypeInfoCalculator diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/TypeSizeReducer.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/TypeSizeReducer.scala index 8d1c0bc7..8cad373f 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/TypeSizeReducer.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/TypeSizeReducer.scala @@ -6,11 +6,11 @@ import com.github.javaparser.ast.stmt.BlockStmt import scala.jdk.CollectionConverters.* object TypeSizeReducer: - def simplifyType(typeDeclaration: TypeDeclaration[?]): Unit = - typeDeclaration - .getMethods() - .asScala - .filter(method => method.getBody().isPresent()) - .foreach { method => - method.setBody(new BlockStmt()) - } + def simplifyType(typeDeclaration: TypeDeclaration[?]): Unit = + typeDeclaration + .getMethods() + .asScala + .filter(method => method.getBody().isPresent()) + .foreach { method => + method.setBody(new BlockStmt()) + } diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/noncaching/JdkJarTypeSolver.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/noncaching/JdkJarTypeSolver.scala index 01e75ce1..f2acc92a 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/noncaching/JdkJarTypeSolver.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/typesolvers/noncaching/JdkJarTypeSolver.scala @@ -18,168 +18,168 @@ import scala.util.{Failure, Success, Try, Using} class JdkJarTypeSolver extends TypeSolver: - private val logger = LoggerFactory.getLogger(this.getClass) + private val logger = LoggerFactory.getLogger(this.getClass) - private var parent: Option[TypeSolver] = None - private val classPool = new NonCachingClassPool() + private var parent: Option[TypeSolver] = None + private val classPool = new NonCachingClassPool() - val knownPackagePrefixes: mutable.Set[String] = mutable.Set.empty + val knownPackagePrefixes: mutable.Set[String] = mutable.Set.empty - // Populating this causes memory leaks - val packagesJarMappings: mutable.Map[String, mutable.Set[String]] = mutable.Map.empty + // Populating this causes memory leaks + val packagesJarMappings: mutable.Map[String, mutable.Set[String]] = mutable.Map.empty - private type RefType = ResolvedReferenceTypeDeclaration + private type RefType = ResolvedReferenceTypeDeclaration - override def getParent(): TypeSolver = parent.get + override def getParent(): TypeSolver = parent.get - override def setParent(parent: TypeSolver): Unit = - this.parent match - case None => - this.parent = Some(parent) + override def setParent(parent: TypeSolver): Unit = + this.parent match + case None => + this.parent = Some(parent) - case Some(_) => - throw new RuntimeException("JdkJarTypeSolver parent may only be set once") + case Some(_) => + throw new RuntimeException("JdkJarTypeSolver parent may only be set once") - override def tryToSolveType(javaParserName: String) - : SymbolReference[ResolvedReferenceTypeDeclaration] = - val packagePrefix = packagePrefixForJavaParserName(javaParserName) - if knownPackagePrefixes.contains(packagePrefix) then - lookupType(javaParserName) - else - SymbolReference.unsolved() + override def tryToSolveType(javaParserName: String) + : SymbolReference[ResolvedReferenceTypeDeclaration] = + val packagePrefix = packagePrefixForJavaParserName(javaParserName) + if knownPackagePrefixes.contains(packagePrefix) then + lookupType(javaParserName) + else + SymbolReference.unsolved() - private def lookupType(javaParserName: String) - : SymbolReference[ResolvedReferenceTypeDeclaration] = - val name = convertJavaParserNameToStandard(javaParserName) - Try(classPool.get(name)) match - case Success(ctClass) => - val refType = ctClassToRefType(ctClass) - refTypeToSymbolReference(refType) + private def lookupType(javaParserName: String) + : SymbolReference[ResolvedReferenceTypeDeclaration] = + val name = convertJavaParserNameToStandard(javaParserName) + Try(classPool.get(name)) match + case Success(ctClass) => + val refType = ctClassToRefType(ctClass) + refTypeToSymbolReference(refType) + case Failure(e) => + SymbolReference.unsolved() + + override def solveType(name: String): ResolvedReferenceTypeDeclaration = + tryToSolveType(name) match + case symbolReference if symbolReference.isSolved => + symbolReference.getCorrespondingDeclaration + + case _ => throw new UnsolvedSymbolException(name) + + private def ctClassToRefType(ctClass: CtClass): RefType = + JavassistFactory.toTypeDeclaration(ctClass, getRoot) + + private def refTypeToSymbolReference(refType: RefType): SymbolReference[RefType] = + SymbolReference.solved[RefType, RefType](refType) + + private def addPathToClassPool(archivePath: String): Try[ClassPath] = + if archivePath.isJarPath then + Try(classPool.appendClassPath(archivePath)) + else if archivePath.isJmodPath then + val classPath = new JmodClassPath(archivePath) + Try(classPool.appendClassPath(classPath)) + else + Failure(new IllegalArgumentException("$archivePath is not a path to a jar/jmod")) + + def withJars(archivePaths: Seq[String]): JdkJarTypeSolver = + addArchives(archivePaths) + this + + def addArchives(archivePaths: Seq[String]): Unit = + archivePaths.foreach { archivePath => + addPathToClassPool(archivePath) match + case Success(_) => registerPackagesForJar(archivePath) case Failure(e) => - SymbolReference.unsolved() - - override def solveType(name: String): ResolvedReferenceTypeDeclaration = - tryToSolveType(name) match - case symbolReference if symbolReference.isSolved => - symbolReference.getCorrespondingDeclaration - - case _ => throw new UnsolvedSymbolException(name) - - private def ctClassToRefType(ctClass: CtClass): RefType = - JavassistFactory.toTypeDeclaration(ctClass, getRoot) - - private def refTypeToSymbolReference(refType: RefType): SymbolReference[RefType] = - SymbolReference.solved[RefType, RefType](refType) - - private def addPathToClassPool(archivePath: String): Try[ClassPath] = - if archivePath.isJarPath then - Try(classPool.appendClassPath(archivePath)) - else if archivePath.isJmodPath then - val classPath = new JmodClassPath(archivePath) - Try(classPool.appendClassPath(classPath)) - else - Failure(new IllegalArgumentException("$archivePath is not a path to a jar/jmod")) - - def withJars(archivePaths: Seq[String]): JdkJarTypeSolver = - addArchives(archivePaths) - this - - def addArchives(archivePaths: Seq[String]): Unit = - archivePaths.foreach { archivePath => - addPathToClassPool(archivePath) match - case Success(_) => registerPackagesForJar(archivePath) - case Failure(e) => - } - - private def registerPackagesForJar(archivePath: String): Unit = - val entryNameConverter = - if archivePath.isJarPath then packagePrefixForJarEntry else packagePrefixForJmodEntry - try - Using(new JarFile(archivePath)) { jarFile => - def jarPackages = jarFile - .entries() - .asIterator() - .asScala - .filter(entry => - !entry.isDirectory && !entry.getName - .startsWith("module-info") && (entry.getName.endsWith( - ClassExtension - ) || entry.getName - .endsWith(JavaExtension) || entry.getName.endsWith(KtExtension)) - ) - knownPackagePrefixes ++= jarPackages.map(entry => entryNameConverter(entry.getName)) - } - catch - case ioException: IOException => - end registerPackagesForJar + } + + private def registerPackagesForJar(archivePath: String): Unit = + val entryNameConverter = + if archivePath.isJarPath then packagePrefixForJarEntry else packagePrefixForJmodEntry + try + Using(new JarFile(archivePath)) { jarFile => + def jarPackages = jarFile + .entries() + .asIterator() + .asScala + .filter(entry => + !entry.isDirectory && !entry.getName + .startsWith("module-info") && (entry.getName.endsWith( + ClassExtension + ) || entry.getName + .endsWith(JavaExtension) || entry.getName.endsWith(KtExtension)) + ) + knownPackagePrefixes ++= jarPackages.map(entry => entryNameConverter(entry.getName)) + } + catch + case ioException: IOException => + end registerPackagesForJar end JdkJarTypeSolver object JdkJarTypeSolver: - val ClassExtension: String = ".class" - val JavaExtension: String = ".java" - val KtExtension: String = ".kt" - val JmodClassPrefix: String = "classes/" - val JarExtension: String = ".jar" - val JmodExtension: String = ".jmod" - - extension (path: String) - def isJarPath: Boolean = path.endsWith(JarExtension) - def isJmodPath: Boolean = path.endsWith(JmodExtension) - - def fromJdkPath(jdkPath: String): JdkJarTypeSolver = - val jarPaths = SourceFiles.determine(jdkPath, Set(JarExtension, JmodExtension)) - if jarPaths.nonEmpty then new JdkJarTypeSolver().withJars(jarPaths) - else new JdkJarTypeSolver() - - /** Convert JavaParser class name foo.bar.qux.Baz to package prefix foo.bar Only use first 2 - * parts since this is sufficient to deterimine whether a class has been registered in most - * cases and, if not, the failure is just a slow lookup. - */ - def packagePrefixForJavaParserName(className: String): String = - className.split("\\.").take(2).mkString(".") - - /** Convert Jar entry name foo/bar/qux/Baz.class to package prefix foo.bar Only use first 2 - * parts since this is sufficient to deterimine whether a class has been registered in most - * cases and, if not, the failure is just a slow lookup. - */ - def packagePrefixForJarEntry(entryName: String): String = - entryName.split("/").take(2).mkString(".") - - /** Convert jmod entry name classes/foo/bar/qux/Baz.class to package prefix foo.bar Only use - * first 2 parts since this is sufficient to deterimine whether a class has been registered in - * most cases and, if not, the failure is just a slow lookup. - */ - def packagePrefixForJmodEntry(entryName: String): String = - packagePrefixForJarEntry(entryName.stripPrefix(JmodClassPrefix)) - - /** A name is assumed to contain at least one subclass (e.g. ...Foo$Bar) if the last name part - * starts with a digit, or if the last 2 name parts start with capital letters. This heuristic - * is based on the class name format in the JDK jars, where names with subclasses have one of - * the forms: - * - java.lang.ClassLoader$2 - * - java.lang.ClassLoader$NativeLibrary - * - java.lang.ClassLoader$NativeLibrary$Unloader - */ - private def namePartsContainSubclass(nameParts: Array[String]): Boolean = - nameParts.takeRight(2) match - case Array() => false - - case Array(singlePart) => false - - case Array(secondLast, last) => - last.head.isDigit || (secondLast.head.isUpper && last.head.isUpper) - - /** JavaParser replaces the `$` in nested class names with a `.`. This method converts the - * JavaParser names to the standard format by replacing the `.` between name parts that start - * with a capital letter or a digit with a `$` since the jdk classes follow the standard - * practice of capitalising the first letter in class names but not package names. - */ - def convertJavaParserNameToStandard(className: String): String = - className.split(".") match - case nameParts if namePartsContainSubclass(nameParts) => - val (packagePrefix, classNames) = nameParts.partition(_.head.isLower) - s"${packagePrefix.mkString(".")}.${classNames.mkString("$")}" - - case _ => className + val ClassExtension: String = ".class" + val JavaExtension: String = ".java" + val KtExtension: String = ".kt" + val JmodClassPrefix: String = "classes/" + val JarExtension: String = ".jar" + val JmodExtension: String = ".jmod" + + extension (path: String) + def isJarPath: Boolean = path.endsWith(JarExtension) + def isJmodPath: Boolean = path.endsWith(JmodExtension) + + def fromJdkPath(jdkPath: String): JdkJarTypeSolver = + val jarPaths = SourceFiles.determine(jdkPath, Set(JarExtension, JmodExtension)) + if jarPaths.nonEmpty then new JdkJarTypeSolver().withJars(jarPaths) + else new JdkJarTypeSolver() + + /** Convert JavaParser class name foo.bar.qux.Baz to package prefix foo.bar Only use first 2 parts + * since this is sufficient to deterimine whether a class has been registered in most cases and, + * if not, the failure is just a slow lookup. + */ + def packagePrefixForJavaParserName(className: String): String = + className.split("\\.").take(2).mkString(".") + + /** Convert Jar entry name foo/bar/qux/Baz.class to package prefix foo.bar Only use first 2 parts + * since this is sufficient to deterimine whether a class has been registered in most cases and, + * if not, the failure is just a slow lookup. + */ + def packagePrefixForJarEntry(entryName: String): String = + entryName.split("/").take(2).mkString(".") + + /** Convert jmod entry name classes/foo/bar/qux/Baz.class to package prefix foo.bar Only use first + * 2 parts since this is sufficient to deterimine whether a class has been registered in most + * cases and, if not, the failure is just a slow lookup. + */ + def packagePrefixForJmodEntry(entryName: String): String = + packagePrefixForJarEntry(entryName.stripPrefix(JmodClassPrefix)) + + /** A name is assumed to contain at least one subclass (e.g. ...Foo$Bar) if the last name part + * starts with a digit, or if the last 2 name parts start with capital letters. This heuristic is + * based on the class name format in the JDK jars, where names with subclasses have one of the + * forms: + * - java.lang.ClassLoader$2 + * - java.lang.ClassLoader$NativeLibrary + * - java.lang.ClassLoader$NativeLibrary$Unloader + */ + private def namePartsContainSubclass(nameParts: Array[String]): Boolean = + nameParts.takeRight(2) match + case Array() => false + + case Array(singlePart) => false + + case Array(secondLast, last) => + last.head.isDigit || (secondLast.head.isUpper && last.head.isUpper) + + /** JavaParser replaces the `$` in nested class names with a `.`. This method converts the + * JavaParser names to the standard format by replacing the `.` between name parts that start + * with a capital letter or a digit with a `$` since the jdk classes follow the standard practice + * of capitalising the first letter in class names but not package names. + */ + def convertJavaParserNameToStandard(className: String): String = + className.split(".") match + case nameParts if namePartsContainSubclass(nameParts) => + val (packagePrefix, classNames) = nameParts.partition(_.head.isLower) + s"${packagePrefix.mkString(".")}.${classNames.mkString("$")}" + + case _ => className end JdkJarTypeSolver diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/BindingTable.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/BindingTable.scala index 3e88a083..925690bf 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/BindingTable.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/BindingTable.scala @@ -5,98 +5,98 @@ import scala.collection.mutable case class BindingTableEntry(name: String, signature: String, implementingMethodFullName: String) class BindingTable(): - private val entries = mutable.Map.empty[String, BindingTableEntry] + private val entries = mutable.Map.empty[String, BindingTableEntry] - def add(entry: BindingTableEntry): Unit = - entries.put(entry.name + entry.signature, entry) + def add(entry: BindingTableEntry): Unit = + entries.put(entry.name + entry.signature, entry) - def getEntries: Iterable[BindingTableEntry] = - entries.values + def getEntries: Iterable[BindingTableEntry] = + entries.values trait BindingTableAdapter[InputTypeDecl, AstTypeDecl, AstMethodDecl, TypeMap]: - def directParents(typeDecl: InputTypeDecl): collection.Seq[AstTypeDecl] + def directParents(typeDecl: InputTypeDecl): collection.Seq[AstTypeDecl] - def allParentsWithTypeMap(typeDecl: InputTypeDecl): collection.Seq[(AstTypeDecl, TypeMap)] + def allParentsWithTypeMap(typeDecl: InputTypeDecl): collection.Seq[(AstTypeDecl, TypeMap)] - def directBindingTableEntries( - typeDeclFullName: String, - typeDecl: InputTypeDecl - ): collection.Seq[BindingTableEntry] + def directBindingTableEntries( + typeDeclFullName: String, + typeDecl: InputTypeDecl + ): collection.Seq[BindingTableEntry] - def getDeclaredMethods(typeDecl: AstTypeDecl): Iterable[(String, AstMethodDecl)] + def getDeclaredMethods(typeDecl: AstTypeDecl): Iterable[(String, AstMethodDecl)] - def getMethodSignature(methodDecl: AstMethodDecl, typeMap: TypeMap): String + def getMethodSignature(methodDecl: AstMethodDecl, typeMap: TypeMap): String - def getMethodSignatureForEmptyTypeMap(methodDecl: AstMethodDecl): String + def getMethodSignatureForEmptyTypeMap(methodDecl: AstMethodDecl): String - def typeDeclEquals(astTypeDecl: AstTypeDecl, inputTypeDecl: InputTypeDecl): Boolean + def typeDeclEquals(astTypeDecl: AstTypeDecl, inputTypeDecl: InputTypeDecl): Boolean object BindingTable: - def createBindingTable[InputTypeDecl, AstTypeDecl, AstMethodDecl, TypeMap]( - typeDeclFullName: String, - typeDecl: InputTypeDecl, - getBindingTable: AstTypeDecl => BindingTable, - adapter: BindingTableAdapter[InputTypeDecl, AstTypeDecl, AstMethodDecl, TypeMap] - ): BindingTable = - val bindingTable = new BindingTable() - - // Take over all binding table entries for parent class/interface binding tables. - adapter.directParents(typeDecl).filterNot(adapter.typeDeclEquals(_, typeDecl)).foreach { - parentTypeDecl => - val parentBindingTable = - try - getBindingTable(parentTypeDecl) - catch - case e: StackOverflowError => - throw new RuntimeException( - s"SOE getting binding table for $typeDeclFullName" - ) - parentBindingTable.getEntries.foreach { entry => - bindingTable.add(entry) - } - } - - // Create table entries for all methods declared in type declaration. - val directTableEntries = adapter.directBindingTableEntries(typeDeclFullName, typeDecl) - - // Add all table entries for method of type declaration to binding table. - // It is important that this happens after adding the inherited entries - // because later entries for the same slot (same name and signature) - // override previously added entries. - directTableEntries.foreach(bindingTable.add) - - // Override the bindings for generic base class methods if they are overriden. - // To do so we need to traverse all methods in all parent type and calculate - // their signature in the derived type declarations context, meaning with the - // concrete values for the generic type parameters. If this signature together - // with the name matches a direct table entry we have an override and replace - // the binding table entry for the erased! parent method signature. - // This become necessary because calls in the JVM executed via erased signatures. - adapter.allParentsWithTypeMap(typeDecl).foreach { - case (parentTypeDecl, typeParameterInDerivedContext) => - directTableEntries.foreach { directTableEntry => - val parentMethods = adapter.getDeclaredMethods(parentTypeDecl) - parentMethods.foreach { case (parentName, parentMethodDecl) => - if directTableEntry.name == parentName then - val parentSigInDerivedContext = adapter.getMethodSignature( - parentMethodDecl, - typeParameterInDerivedContext - ) - if directTableEntry.signature == parentSigInDerivedContext then - val erasedParentMethodSig = - adapter.getMethodSignatureForEmptyTypeMap(parentMethodDecl) - val tableEntry = BindingTableEntry - .apply( - directTableEntry.name, - erasedParentMethodSig, - directTableEntry.implementingMethodFullName - ) - bindingTable.add(tableEntry) - } - } - } - - bindingTable - end createBindingTable + def createBindingTable[InputTypeDecl, AstTypeDecl, AstMethodDecl, TypeMap]( + typeDeclFullName: String, + typeDecl: InputTypeDecl, + getBindingTable: AstTypeDecl => BindingTable, + adapter: BindingTableAdapter[InputTypeDecl, AstTypeDecl, AstMethodDecl, TypeMap] + ): BindingTable = + val bindingTable = new BindingTable() + + // Take over all binding table entries for parent class/interface binding tables. + adapter.directParents(typeDecl).filterNot(adapter.typeDeclEquals(_, typeDecl)).foreach { + parentTypeDecl => + val parentBindingTable = + try + getBindingTable(parentTypeDecl) + catch + case e: StackOverflowError => + throw new RuntimeException( + s"SOE getting binding table for $typeDeclFullName" + ) + parentBindingTable.getEntries.foreach { entry => + bindingTable.add(entry) + } + } + + // Create table entries for all methods declared in type declaration. + val directTableEntries = adapter.directBindingTableEntries(typeDeclFullName, typeDecl) + + // Add all table entries for method of type declaration to binding table. + // It is important that this happens after adding the inherited entries + // because later entries for the same slot (same name and signature) + // override previously added entries. + directTableEntries.foreach(bindingTable.add) + + // Override the bindings for generic base class methods if they are overriden. + // To do so we need to traverse all methods in all parent type and calculate + // their signature in the derived type declarations context, meaning with the + // concrete values for the generic type parameters. If this signature together + // with the name matches a direct table entry we have an override and replace + // the binding table entry for the erased! parent method signature. + // This become necessary because calls in the JVM executed via erased signatures. + adapter.allParentsWithTypeMap(typeDecl).foreach { + case (parentTypeDecl, typeParameterInDerivedContext) => + directTableEntries.foreach { directTableEntry => + val parentMethods = adapter.getDeclaredMethods(parentTypeDecl) + parentMethods.foreach { case (parentName, parentMethodDecl) => + if directTableEntry.name == parentName then + val parentSigInDerivedContext = adapter.getMethodSignature( + parentMethodDecl, + typeParameterInDerivedContext + ) + if directTableEntry.signature == parentSigInDerivedContext then + val erasedParentMethodSig = + adapter.getMethodSignatureForEmptyTypeMap(parentMethodDecl) + val tableEntry = BindingTableEntry + .apply( + directTableEntry.name, + erasedParentMethodSig, + directTableEntry.implementingMethodFullName + ) + bindingTable.add(tableEntry) + } + } + } + + bindingTable + end createBindingTable end BindingTable diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/BindingTableAdapterImpls.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/BindingTableAdapterImpls.scala index 4d0839ac..82d58358 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/BindingTableAdapterImpls.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/BindingTableAdapterImpls.scala @@ -16,16 +16,16 @@ import scala.jdk.OptionConverters.RichOptional import scala.jdk.CollectionConverters.* object Shared: - def getDeclaredMethods(typeDecl: ResolvedReferenceTypeDeclaration) - : Iterable[ResolvedMethodDeclaration] = - typeDecl match - // Attempting to get declared methods for annotations throws an UnsupportedOperationException. - // See https://github.com/javaparser/javaparser/issues/1838 for details. - case _: JavaParserAnnotationDeclaration => Set.empty - case _: ReflectionAnnotationDeclaration => Set.empty - case _: JavassistAnnotationDeclaration => Set.empty + def getDeclaredMethods(typeDecl: ResolvedReferenceTypeDeclaration) + : Iterable[ResolvedMethodDeclaration] = + typeDecl match + // Attempting to get declared methods for annotations throws an UnsupportedOperationException. + // See https://github.com/javaparser/javaparser/issues/1838 for details. + case _: JavaParserAnnotationDeclaration => Set.empty + case _: ReflectionAnnotationDeclaration => Set.empty + case _: JavassistAnnotationDeclaration => Set.empty - case _ => typeDecl.getDeclaredMethods.asScala + case _ => typeDecl.getDeclaredMethods.asScala class BindingTableAdapterForJavaparser( methodSignatureImpl: (ResolvedMethodDeclaration, ResolvedTypeParametersMap) => String @@ -36,53 +36,53 @@ class BindingTableAdapterForJavaparser( ResolvedTypeParametersMap ]: - override def directParents( - typeDecl: ResolvedReferenceTypeDeclaration - ): collection.Seq[ResolvedReferenceTypeDeclaration] = - safeGetAncestors(typeDecl).map(_.getTypeDeclaration.get) - - override def allParentsWithTypeMap( - typeDecl: ResolvedReferenceTypeDeclaration - ): collection.Seq[(ResolvedReferenceTypeDeclaration, ResolvedTypeParametersMap)] = - getAllParents(typeDecl).map { parentType => - (parentType.getTypeDeclaration.get, parentType.typeParametersMap()) - } - - override def directBindingTableEntries( - typeDeclFullName: String, - typeDecl: ResolvedReferenceTypeDeclaration - ): collection.Seq[BindingTableEntry] = - getDeclaredMethods(typeDecl) - .filter { case (_, methodDecl) => !methodDecl.isStatic } - .map { case (_, methodDecl) => - val signature = getMethodSignature(methodDecl, ResolvedTypeParametersMap.empty()) - BindingTableEntry.apply( - methodDecl.getName, - signature, - composeMethodFullName(typeDeclFullName, methodDecl.getName, signature) - ) - } - .toBuffer - - override def getDeclaredMethods( - typeDecl: ResolvedReferenceTypeDeclaration - ): Iterable[(String, ResolvedMethodDeclaration)] = - Shared.getDeclaredMethods(typeDecl).map(method => (method.getName, method)) - - override def getMethodSignature( - methodDecl: ResolvedMethodDeclaration, - typeMap: ResolvedTypeParametersMap - ): String = - methodSignatureImpl(methodDecl, typeMap) - - override def getMethodSignatureForEmptyTypeMap(methodDecl: ResolvedMethodDeclaration): String = - methodSignatureImpl(methodDecl, ResolvedTypeParametersMap.empty()) - - override def typeDeclEquals( - astTypeDecl: ResolvedReferenceTypeDeclaration, - inputTypeDecl: ResolvedReferenceTypeDeclaration - ): Boolean = - astTypeDecl.getQualifiedName == inputTypeDecl.getQualifiedName + override def directParents( + typeDecl: ResolvedReferenceTypeDeclaration + ): collection.Seq[ResolvedReferenceTypeDeclaration] = + safeGetAncestors(typeDecl).map(_.getTypeDeclaration.get) + + override def allParentsWithTypeMap( + typeDecl: ResolvedReferenceTypeDeclaration + ): collection.Seq[(ResolvedReferenceTypeDeclaration, ResolvedTypeParametersMap)] = + getAllParents(typeDecl).map { parentType => + (parentType.getTypeDeclaration.get, parentType.typeParametersMap()) + } + + override def directBindingTableEntries( + typeDeclFullName: String, + typeDecl: ResolvedReferenceTypeDeclaration + ): collection.Seq[BindingTableEntry] = + getDeclaredMethods(typeDecl) + .filter { case (_, methodDecl) => !methodDecl.isStatic } + .map { case (_, methodDecl) => + val signature = getMethodSignature(methodDecl, ResolvedTypeParametersMap.empty()) + BindingTableEntry.apply( + methodDecl.getName, + signature, + composeMethodFullName(typeDeclFullName, methodDecl.getName, signature) + ) + } + .toBuffer + + override def getDeclaredMethods( + typeDecl: ResolvedReferenceTypeDeclaration + ): Iterable[(String, ResolvedMethodDeclaration)] = + Shared.getDeclaredMethods(typeDecl).map(method => (method.getName, method)) + + override def getMethodSignature( + methodDecl: ResolvedMethodDeclaration, + typeMap: ResolvedTypeParametersMap + ): String = + methodSignatureImpl(methodDecl, typeMap) + + override def getMethodSignatureForEmptyTypeMap(methodDecl: ResolvedMethodDeclaration): String = + methodSignatureImpl(methodDecl, ResolvedTypeParametersMap.empty()) + + override def typeDeclEquals( + astTypeDecl: ResolvedReferenceTypeDeclaration, + inputTypeDecl: ResolvedReferenceTypeDeclaration + ): Boolean = + astTypeDecl.getQualifiedName == inputTypeDecl.getQualifiedName end BindingTableAdapterForJavaparser case class LambdaBindingInfo( @@ -100,46 +100,46 @@ class BindingTableAdapterForLambdas( ResolvedTypeParametersMap ]: - override def directParents(lambdaBindingInfo: LambdaBindingInfo) - : collection.Seq[ResolvedReferenceTypeDeclaration] = - lambdaBindingInfo.implementedType.flatMap(_.getTypeDeclaration.toScala).toList - - override def allParentsWithTypeMap( - lambdaBindingInfo: LambdaBindingInfo - ): collection.Seq[(ResolvedReferenceTypeDeclaration, ResolvedTypeParametersMap)] = - val nonDirectParents = - lambdaBindingInfo.implementedType.flatMap(_.getTypeDeclaration.toScala).toList.flatMap( - getAllParents - ) - (lambdaBindingInfo.implementedType.toList ++ nonDirectParents).map { typ => - (typ.getTypeDeclaration.get, typ.typeParametersMap()) - } - - override def directBindingTableEntries( - typeDeclFullName: String, - lambdaBindingInfo: LambdaBindingInfo - ): collection.Seq[BindingTableEntry] = - lambdaBindingInfo.directBinding.map { binding => - BindingTableEntry(binding.name, binding.signature, binding.methodFullName) - }.toList - - override def getDeclaredMethods( - typeDecl: ResolvedReferenceTypeDeclaration - ): Iterable[(String, ResolvedMethodDeclaration)] = - Shared.getDeclaredMethods(typeDecl).map(method => (method.getName, method)) - - override def getMethodSignature( - methodDecl: ResolvedMethodDeclaration, - typeMap: ResolvedTypeParametersMap - ): String = - methodSignatureImpl(methodDecl, typeMap) - - override def getMethodSignatureForEmptyTypeMap(methodDecl: ResolvedMethodDeclaration): String = - methodSignatureImpl(methodDecl, ResolvedTypeParametersMap.empty()) - - override def typeDeclEquals( - astTypeDecl: ResolvedReferenceTypeDeclaration, - inputTypeDecl: LambdaBindingInfo - ): Boolean = - astTypeDecl.getQualifiedName == inputTypeDecl.fullName + override def directParents(lambdaBindingInfo: LambdaBindingInfo) + : collection.Seq[ResolvedReferenceTypeDeclaration] = + lambdaBindingInfo.implementedType.flatMap(_.getTypeDeclaration.toScala).toList + + override def allParentsWithTypeMap( + lambdaBindingInfo: LambdaBindingInfo + ): collection.Seq[(ResolvedReferenceTypeDeclaration, ResolvedTypeParametersMap)] = + val nonDirectParents = + lambdaBindingInfo.implementedType.flatMap(_.getTypeDeclaration.toScala).toList.flatMap( + getAllParents + ) + (lambdaBindingInfo.implementedType.toList ++ nonDirectParents).map { typ => + (typ.getTypeDeclaration.get, typ.typeParametersMap()) + } + + override def directBindingTableEntries( + typeDeclFullName: String, + lambdaBindingInfo: LambdaBindingInfo + ): collection.Seq[BindingTableEntry] = + lambdaBindingInfo.directBinding.map { binding => + BindingTableEntry(binding.name, binding.signature, binding.methodFullName) + }.toList + + override def getDeclaredMethods( + typeDecl: ResolvedReferenceTypeDeclaration + ): Iterable[(String, ResolvedMethodDeclaration)] = + Shared.getDeclaredMethods(typeDecl).map(method => (method.getName, method)) + + override def getMethodSignature( + methodDecl: ResolvedMethodDeclaration, + typeMap: ResolvedTypeParametersMap + ): String = + methodSignatureImpl(methodDecl, typeMap) + + override def getMethodSignatureForEmptyTypeMap(methodDecl: ResolvedMethodDeclaration): String = + methodSignatureImpl(methodDecl, ResolvedTypeParametersMap.empty()) + + override def typeDeclEquals( + astTypeDecl: ResolvedReferenceTypeDeclaration, + inputTypeDecl: LambdaBindingInfo + ): Boolean = + astTypeDecl.getQualifiedName == inputTypeDecl.fullName end BindingTableAdapterForLambdas diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/Delombok.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/Delombok.scala index 4a75e0a3..ede3184a 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/Delombok.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/Delombok.scala @@ -10,80 +10,80 @@ import scala.util.{Failure, Success, Try} object Delombok: - sealed trait DelombokMode - // Don't run delombok at all. - object DelombokMode: - case object NoDelombok extends DelombokMode - case object Default extends DelombokMode - case object TypesOnly extends DelombokMode - case object RunDelombok extends DelombokMode + sealed trait DelombokMode + // Don't run delombok at all. + object DelombokMode: + case object NoDelombok extends DelombokMode + case object Default extends DelombokMode + case object TypesOnly extends DelombokMode + case object RunDelombok extends DelombokMode - private val logger = LoggerFactory.getLogger(this.getClass) + private val logger = LoggerFactory.getLogger(this.getClass) - private def systemJavaPath: String = - sys.env - .get("JAVA_HOME") - .flatMap { javaHome => - val javaExecutable = File(javaHome, "bin", "java") - Option.when(javaExecutable.exists && javaExecutable.isExecutable) { - javaExecutable.canonicalPath - } + private def systemJavaPath: String = + sys.env + .get("JAVA_HOME") + .flatMap { javaHome => + val javaExecutable = File(javaHome, "bin", "java") + Option.when(javaExecutable.exists && javaExecutable.isExecutable) { + javaExecutable.canonicalPath } - .getOrElse("java") + } + .getOrElse("java") - private def delombokToTempDirCommand(tempDir: File, analysisJavaHome: Option[String]) = - val javaPath = analysisJavaHome.getOrElse(systemJavaPath) - val classPathArg = Try(File.newTemporaryFile("classpath").deleteOnExit()) match - case Success(file) => - if System.getProperty("java.class.path").nonEmpty then - // Write classpath to a file to work around Windows length limits. - file.write(System.getProperty("java.class.path")) - s"@${file.canonicalPath}" - else System.getProperty("java.class.path") - case Failure(t) => - logger.debug( - s"Failed to create classpath file for delombok execution. Results may be missing on Windows systems" - ) - System.getProperty("java.class.path") - if classPathArg.nonEmpty then - s"$javaPath -cp $classPathArg lombok.launch.Main delombok . -d ${tempDir.canonicalPath}" - else "" + private def delombokToTempDirCommand(tempDir: File, analysisJavaHome: Option[String]) = + val javaPath = analysisJavaHome.getOrElse(systemJavaPath) + val classPathArg = Try(File.newTemporaryFile("classpath").deleteOnExit()) match + case Success(file) => + if System.getProperty("java.class.path").nonEmpty then + // Write classpath to a file to work around Windows length limits. + file.write(System.getProperty("java.class.path")) + s"@${file.canonicalPath}" + else System.getProperty("java.class.path") + case Failure(t) => + logger.debug( + s"Failed to create classpath file for delombok execution. Results may be missing on Windows systems" + ) + System.getProperty("java.class.path") + if classPathArg.nonEmpty then + s"$javaPath -cp $classPathArg lombok.launch.Main delombok . -d ${tempDir.canonicalPath}" + else "" - def run(projectDir: String, analysisJavaHome: Option[String]): String = - Try(File.newTemporaryDirectory(prefix = "delombok").deleteOnExit()) match - case Success(tempDir) => - val externalCommand = delombokToTempDirCommand(tempDir, analysisJavaHome) - if externalCommand.nonEmpty then - ExternalCommand.run( - externalCommand, - cwd = projectDir - ) match - case Success(_) => - tempDir.path.toAbsolutePath.toString + def run(projectDir: String, analysisJavaHome: Option[String]): String = + Try(File.newTemporaryDirectory(prefix = "delombok").deleteOnExit()) match + case Success(tempDir) => + val externalCommand = delombokToTempDirCommand(tempDir, analysisJavaHome) + if externalCommand.nonEmpty then + ExternalCommand.run( + externalCommand, + cwd = projectDir + ) match + case Success(_) => + tempDir.path.toAbsolutePath.toString - case Failure(t) => - logger.debug(s"Executing delombok failed", t) - logger.debug( - "Creating AST with original source instead. Some methods and type information will be missing." - ) - projectDir - else "" + case Failure(t) => + logger.debug(s"Executing delombok failed", t) + logger.debug( + "Creating AST with original source instead. Some methods and type information will be missing." + ) + projectDir + else "" - case Failure(e) => - logger.debug( - s"Failed to create temporary directory for delomboked source. Methods and types may be missing", - e - ) - projectDir + case Failure(e) => + logger.debug( + s"Failed to create temporary directory for delomboked source. Methods and types may be missing", + e + ) + projectDir - def parseDelombokModeOption(delombokModeStr: Option[String]): DelombokMode = - delombokModeStr.map(_.toLowerCase) match - case None => Default - case Some("no-delombok") => NoDelombok - case Some("default") => Default - case Some("types-only") => TypesOnly - case Some("run-delombok") => RunDelombok - case Some(value) => - logger.debug(s"Found unrecognised delombok mode `$value`. Using default instead.") - Default + def parseDelombokModeOption(delombokModeStr: Option[String]): DelombokMode = + delombokModeStr.map(_.toLowerCase) match + case None => Default + case Some("no-delombok") => NoDelombok + case Some("default") => Default + case Some("types-only") => TypesOnly + case Some("run-delombok") => RunDelombok + case Some(value) => + logger.debug(s"Found unrecognised delombok mode `$value`. Using default instead.") + Default end Delombok diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/NameConstants.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/NameConstants.scala index d1020e2f..8973886f 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/NameConstants.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/NameConstants.scala @@ -1,6 +1,6 @@ package io.appthreat.javasrc2cpg.util object NameConstants: - val Super: String = "super" - val This: String = "this" - val WildcardImportName: String = "*" + val Super: String = "super" + val This: String = "this" + val WildcardImportName: String = "*" diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/SourceParser.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/SourceParser.scala index a83936c5..04c539a9 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/SourceParser.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/SourceParser.scala @@ -24,105 +24,105 @@ import scala.util.Success class SourceParser private (originalInputPath: Path, analysisRoot: Path, typesRoot: Path): - /** Parse the given file into a JavaParser CompliationUnit that will be used for creating the - * CPG AST. - * - * @param relativeFilename - * path to the input file relative to the project root. - */ - def parseAnalysisFile(relativeFilename: String): Option[CompilationUnit] = - val analysisFilename = analysisRoot.resolve(relativeFilename).toString - // Need to store tokens for position information. - fileIfExists(analysisFilename).flatMap(parse(_, storeTokens = true)) - - /** Parse the given file into a JavaParser CompliationUnit that will be used for reading type - * information. These should not be used for determining the structure of the AST. - * - * @param relativeFilename - * path to the input file relative to the project root. - */ - def parseTypesFile(relativeFilename: String): Option[CompilationUnit] = - val typesFilename = typesRoot.resolve(relativeFilename).toString - fileIfExists(typesFilename).flatMap(parse(_, storeTokens = false)) - - def fileIfExists(filename: String): Option[File] = - val file = File(filename) - - Option.when(file.exists)(file) - - def getTypesFileLines(relativeFilename: String): Try[Iterable[String]] = - val typesFilename = typesRoot.resolve(relativeFilename).toString - Try(File(typesFilename).lines(Charset.defaultCharset())) - .orElse(Try(File(typesFilename).lines(StandardCharsets.ISO_8859_1))) - - def doesTypesFileExist(relativeFilename: String): Boolean = - File(typesRoot.resolve(relativeFilename)).isRegularFile - - private def parse(file: File, storeTokens: Boolean): Option[CompilationUnit] = - val javaParserConfig = - new ParserConfiguration() - .setLanguageLevel(LanguageLevel.BLEEDING_EDGE) - .setAttributeComments(false) - .setLexicalPreservationEnabled(true) - .setStoreTokens(storeTokens) - val parseResult = new JavaParser(javaParserConfig).parse(file.toJava) - - parseResult.getResult.toScala match - case Some(result) if result.getParsed == Parsedness.PARSED => Some(result) - case _ => - None + /** Parse the given file into a JavaParser CompliationUnit that will be used for creating the CPG + * AST. + * + * @param relativeFilename + * path to the input file relative to the project root. + */ + def parseAnalysisFile(relativeFilename: String): Option[CompilationUnit] = + val analysisFilename = analysisRoot.resolve(relativeFilename).toString + // Need to store tokens for position information. + fileIfExists(analysisFilename).flatMap(parse(_, storeTokens = true)) + + /** Parse the given file into a JavaParser CompliationUnit that will be used for reading type + * information. These should not be used for determining the structure of the AST. + * + * @param relativeFilename + * path to the input file relative to the project root. + */ + def parseTypesFile(relativeFilename: String): Option[CompilationUnit] = + val typesFilename = typesRoot.resolve(relativeFilename).toString + fileIfExists(typesFilename).flatMap(parse(_, storeTokens = false)) + + def fileIfExists(filename: String): Option[File] = + val file = File(filename) + + Option.when(file.exists)(file) + + def getTypesFileLines(relativeFilename: String): Try[Iterable[String]] = + val typesFilename = typesRoot.resolve(relativeFilename).toString + Try(File(typesFilename).lines(Charset.defaultCharset())) + .orElse(Try(File(typesFilename).lines(StandardCharsets.ISO_8859_1))) + + def doesTypesFileExist(relativeFilename: String): Boolean = + File(typesRoot.resolve(relativeFilename)).isRegularFile + + private def parse(file: File, storeTokens: Boolean): Option[CompilationUnit] = + val javaParserConfig = + new ParserConfiguration() + .setLanguageLevel(LanguageLevel.BLEEDING_EDGE) + .setAttributeComments(false) + .setLexicalPreservationEnabled(true) + .setStoreTokens(storeTokens) + val parseResult = new JavaParser(javaParserConfig).parse(file.toJava) + + parseResult.getResult.toScala match + case Some(result) if result.getParsed == Parsedness.PARSED => Some(result) + case _ => + None end SourceParser object SourceParser: - def apply(config: Config, hasLombokDependency: Boolean): SourceParser = - val canonicalInputPath = File(config.inputPath).canonicalPath - val (analysisDir, typesDir) = - getAnalysisAndTypesDirs( - canonicalInputPath, - config.delombokJavaHome, - config.delombokMode, - hasLombokDependency - ) - new SourceParser(Path.of(canonicalInputPath), Path.of(analysisDir), Path.of(typesDir)) - - def getSourceFilenames( - config: Config, - sourcesOverride: Option[List[String]] = None - ): Array[String] = - val inputPaths = sourcesOverride.getOrElse(config.inputPath :: Nil).toSet - SourceFiles.determine(inputPaths, JavaSrc2Cpg.sourceFileExtensions, config).toArray - - /** Implements the logic described in the option description for the "delombok-mode" option: - * - no-delombok: do not run delombok. - * - default: run delombok if a lombok dependency is found and analyse delomboked code. - * - types-only: run delombok, but use it for type information only - * - run-delombok: run delombok and analyse delomboked code - * - * @return - * the tuple (analysisRoot, typesRoot) where analysisRoot is used to locate source files for - * creating the AST and typesRoot is used for locating source files from which to extract - * type information. - */ - private def getAnalysisAndTypesDirs( - originalDir: String, - delombokJavaHome: Option[String], - delombokMode: Option[String], - hasLombokDependency: Boolean - ): (String, String) = - lazy val delombokDir = Delombok.run(originalDir, delombokJavaHome) - if delombokDir.nonEmpty then - Delombok.parseDelombokModeOption(delombokMode) match - case Default if hasLombokDependency => - (delombokDir, delombokDir) - - case Default => (originalDir, originalDir) - - case NoDelombok => (originalDir, originalDir) - - case TypesOnly => (originalDir, delombokDir) - - case RunDelombok => (delombokDir, delombokDir) - else (delombokDir, delombokDir) - end getAnalysisAndTypesDirs + def apply(config: Config, hasLombokDependency: Boolean): SourceParser = + val canonicalInputPath = File(config.inputPath).canonicalPath + val (analysisDir, typesDir) = + getAnalysisAndTypesDirs( + canonicalInputPath, + config.delombokJavaHome, + config.delombokMode, + hasLombokDependency + ) + new SourceParser(Path.of(canonicalInputPath), Path.of(analysisDir), Path.of(typesDir)) + + def getSourceFilenames( + config: Config, + sourcesOverride: Option[List[String]] = None + ): Array[String] = + val inputPaths = sourcesOverride.getOrElse(config.inputPath :: Nil).toSet + SourceFiles.determine(inputPaths, JavaSrc2Cpg.sourceFileExtensions, config).toArray + + /** Implements the logic described in the option description for the "delombok-mode" option: + * - no-delombok: do not run delombok. + * - default: run delombok if a lombok dependency is found and analyse delomboked code. + * - types-only: run delombok, but use it for type information only + * - run-delombok: run delombok and analyse delomboked code + * + * @return + * the tuple (analysisRoot, typesRoot) where analysisRoot is used to locate source files for + * creating the AST and typesRoot is used for locating source files from which to extract type + * information. + */ + private def getAnalysisAndTypesDirs( + originalDir: String, + delombokJavaHome: Option[String], + delombokMode: Option[String], + hasLombokDependency: Boolean + ): (String, String) = + lazy val delombokDir = Delombok.run(originalDir, delombokJavaHome) + if delombokDir.nonEmpty then + Delombok.parseDelombokModeOption(delombokMode) match + case Default if hasLombokDependency => + (delombokDir, delombokDir) + + case Default => (originalDir, originalDir) + + case NoDelombok => (originalDir, originalDir) + + case TypesOnly => (originalDir, delombokDir) + + case RunDelombok => (delombokDir, delombokDir) + else (delombokDir, delombokDir) + end getAnalysisAndTypesDirs end SourceParser diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/SourceRootFinder.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/SourceRootFinder.scala index ee80e2f1..7e5488ca 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/SourceRootFinder.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/SourceRootFinder.scala @@ -14,75 +14,75 @@ import better.files.File */ object SourceRootFinder: - private val excludes: Set[String] = Set("test", ".mvn", ".git") + private val excludes: Set[String] = Set("test", ".mvn", ".git") - private def statePreSrc(currentDir: File): List[File] = - currentDir.children.filter(_.isDirectory).toList.flatMap { child => - child.name match - case name if excludes.contains(name) => Nil - case "src" => stateSrc(child) - case "main" => statePostSrc(child) - case "java" => child :: Nil - case _ => statePreSrc(child) - } + private def statePreSrc(currentDir: File): List[File] = + currentDir.children.filter(_.isDirectory).toList.flatMap { child => + child.name match + case name if excludes.contains(name) => Nil + case "src" => stateSrc(child) + case "main" => statePostSrc(child) + case "java" => child :: Nil + case _ => statePreSrc(child) + } - private def stateSrc(currentDir: File): List[File] = - val mainChildren = currentDir.children.filter(_.isDirectory).toList.flatMap { child => - child.name match - case name if excludes.contains(name) => Nil - case "main" => statePostSrc(child) - case _ => Nil - } + private def stateSrc(currentDir: File): List[File] = + val mainChildren = currentDir.children.filter(_.isDirectory).toList.flatMap { child => + child.name match + case name if excludes.contains(name) => Nil + case "main" => statePostSrc(child) + case _ => Nil + } - val nonMainChildren = currentDir.children.filter(_.isDirectory).toList.flatMap { child => - child.name match - case name if excludes.contains(name) => Nil - case "main" => Nil - case _ => statePostSrc(child) - } + val nonMainChildren = currentDir.children.filter(_.isDirectory).toList.flatMap { child => + child.name match + case name if excludes.contains(name) => Nil + case "main" => Nil + case _ => statePostSrc(child) + } - val hasExcludedDir = currentDir.children.filter(_.isDirectory).toList.exists(file => - excludes.contains(file.name) - ) + val hasExcludedDir = currentDir.children.filter(_.isDirectory).toList.exists(file => + excludes.contains(file.name) + ) - (mainChildren, nonMainChildren) match - case (Nil, Nil) => - // probably a src/test directory in a tests-only module - if hasExcludedDir then Nil else List(currentDir) - case (mainC, Nil) => - // probably follows the common src/main/ structure - mainC - case (Nil, _) => - // the non-main children are probably package roots - List(currentDir) - case (mainC, nonMC) => - // main a package root here? - mainC ++ nonMC - end stateSrc + (mainChildren, nonMainChildren) match + case (Nil, Nil) => + // probably a src/test directory in a tests-only module + if hasExcludedDir then Nil else List(currentDir) + case (mainC, Nil) => + // probably follows the common src/main/ structure + mainC + case (Nil, _) => + // the non-main children are probably package roots + List(currentDir) + case (mainC, nonMC) => + // main a package root here? + mainC ++ nonMC + end stateSrc - private def statePostSrc(currentDir: File): List[File] = - val javaChildren = currentDir.children.filter(_.isDirectory).toList.flatMap { child => - child.name match - case name if excludes.contains(name) => Nil - case _ => child :: Nil - } + private def statePostSrc(currentDir: File): List[File] = + val javaChildren = currentDir.children.filter(_.isDirectory).toList.flatMap { child => + child.name match + case name if excludes.contains(name) => Nil + case _ => child :: Nil + } - javaChildren match - case Nil => currentDir :: Nil - case _ => javaChildren + javaChildren match + case Nil => currentDir :: Nil + case _ => javaChildren - private def listBottomLevelSubdirectories(currentDir: File): List[File] = - val srcDirs = currentDir.name match - case name if excludes.contains(name) => Nil - case "src" => stateSrc(currentDir) - case "main" => statePostSrc(currentDir) - case "java" => List(currentDir) - case _ => statePreSrc(currentDir) + private def listBottomLevelSubdirectories(currentDir: File): List[File] = + val srcDirs = currentDir.name match + case name if excludes.contains(name) => Nil + case "src" => stateSrc(currentDir) + case "main" => statePostSrc(currentDir) + case "java" => List(currentDir) + case _ => statePreSrc(currentDir) - srcDirs match - case Nil => List(currentDir) - case _ => srcDirs + srcDirs match + case Nil => List(currentDir) + case _ => srcDirs - def getSourceRoots(codeDir: String): List[String] = - listBottomLevelSubdirectories(File(codeDir)).map(_.pathAsString) + def getSourceRoots(codeDir: String): List[String] = + listBottomLevelSubdirectories(File(codeDir)).map(_.pathAsString) end SourceRootFinder diff --git a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/Util.scala b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/Util.scala index 3033d8ac..db785caf 100644 --- a/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/Util.scala +++ b/platform/frontends/javasrc2cpg/src/main/scala/io/appthreat/javasrc2cpg/util/Util.scala @@ -14,53 +14,53 @@ import scala.jdk.CollectionConverters.* object Util: - private val logger = LoggerFactory.getLogger(this.getClass) - def composeMethodFullName(typeDeclFullName: String, name: String, signature: String): String = - s"$typeDeclFullName.$name:$signature" + private val logger = LoggerFactory.getLogger(this.getClass) + def composeMethodFullName(typeDeclFullName: String, name: String, signature: String): String = + s"$typeDeclFullName.$name:$signature" - def safeGetAncestors(typeDecl: ResolvedReferenceTypeDeclaration): Seq[ResolvedReferenceType] = - Try(typeDecl.getAncestors(true)) match - case Success(ancestors) => ancestors.asScala.filterNot(_ == typeDecl).toSeq + def safeGetAncestors(typeDecl: ResolvedReferenceTypeDeclaration): Seq[ResolvedReferenceType] = + Try(typeDecl.getAncestors(true)) match + case Success(ancestors) => ancestors.asScala.filterNot(_ == typeDecl).toSeq - case Failure(exception) => - logger.debug( - s"Failed to get direct parents for typeDecl ${typeDecl.getQualifiedName}", - exception - ) - Seq.empty + case Failure(exception) => + logger.debug( + s"Failed to get direct parents for typeDecl ${typeDecl.getQualifiedName}", + exception + ) + Seq.empty - def getAllParents(typeDecl: ResolvedReferenceTypeDeclaration) - : mutable.ArrayBuffer[ResolvedReferenceType] = - val result = mutable.ArrayBuffer.empty[ResolvedReferenceType] + def getAllParents(typeDecl: ResolvedReferenceTypeDeclaration) + : mutable.ArrayBuffer[ResolvedReferenceType] = + val result = mutable.ArrayBuffer.empty[ResolvedReferenceType] - if !typeDecl.isJavaLangObject then - safeGetAncestors(typeDecl).filter( - _.getQualifiedName != typeDecl.getQualifiedName - ).foreach { ancestor => - result.append(ancestor) - getAllParents(ancestor, result) - } + if !typeDecl.isJavaLangObject then + safeGetAncestors(typeDecl).filter( + _.getQualifiedName != typeDecl.getQualifiedName + ).foreach { ancestor => + result.append(ancestor) + getAllParents(ancestor, result) + } - result + result - def composeMethodLikeSignature( - returnType: String, - parameterTypes: collection.Seq[String] - ): String = - s"$returnType(${parameterTypes.mkString(",")})" + def composeMethodLikeSignature( + returnType: String, + parameterTypes: collection.Seq[String] + ): String = + s"$returnType(${parameterTypes.mkString(",")})" - def composeUnresolvedSignature(paramCount: Int): String = - s"${Defines.UnresolvedSignature}($paramCount)" + def composeUnresolvedSignature(paramCount: Int): String = + s"${Defines.UnresolvedSignature}($paramCount)" - private def getAllParents( - typ: ResolvedReferenceType, - result: mutable.ArrayBuffer[ResolvedReferenceType] - ): Unit = - if typ.isJavaLangObject then - Iterable.empty - else - Try(typ.getDirectAncestors).map(_.asScala).getOrElse(Nil).foreach { ancestor => - result.append(ancestor) - getAllParents(ancestor, result) - } + private def getAllParents( + typ: ResolvedReferenceType, + result: mutable.ArrayBuffer[ResolvedReferenceType] + ): Unit = + if typ.isJavaLangObject then + Iterable.empty + else + Try(typ.getDirectAncestors).map(_.asScala).getOrElse(Nil).foreach { ancestor => + result.append(ancestor) + getAllParents(ancestor, result) + } end Util diff --git a/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/Jimple2Cpg.scala b/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/Jimple2Cpg.scala index 2bea35db..4f15db0e 100644 --- a/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/Jimple2Cpg.scala +++ b/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/Jimple2Cpg.scala @@ -23,135 +23,135 @@ import scala.language.postfixOps import scala.util.Try object Jimple2Cpg: - val language = "JAVA" + val language = "JAVA" - def apply(): Jimple2Cpg = new Jimple2Cpg() + def apply(): Jimple2Cpg = new Jimple2Cpg() class Jimple2Cpg extends X2CpgFrontend[Config]: - import Jimple2Cpg.* - - private val logger = LoggerFactory.getLogger(classOf[Jimple2Cpg]) - - private def sootLoadApk(input: File, framework: Option[String] = None): Unit = - Options.v().set_process_dir(List(input.canonicalPath).asJava) - framework match - case Some(value) if value.nonEmpty => - Options.v().set_src_prec(Options.src_prec_apk) - Options.v().set_force_android_jar(value) - case _ => - Options.v().set_src_prec(Options.src_prec_apk_c_j) - Options.v().set_process_multiple_dex(true) - // workaround for Soot's bug while parsing large apk. - // see: https://github.com/soot-oss/soot/issues/1256 - Options.v().setPhaseOption("jb", "use-original-names:false") - - /** Load all class files from archives or directories recursively - * @param recurse - * Whether to unpack recursively - * @return - * The list of extracted class files whose package path could be extracted, placed on that - * package path relative to [[tmpDir]] - */ - private def loadClassFiles(src: File, tmpDir: File, recurse: Boolean): List[ClassFile] = - val archiveFileExtensions = Set(".jar", ".war", ".zip") - extractClassesInPackageLayout( - src, - tmpDir, - isClass = e => e.extension.contains(".class"), - isArchive = e => e.extension.exists(archiveFileExtensions.contains), - recurse - ) - - /** Extract all class files found, place them in their package layout and load them into soot. - * @param input - * The file/directory to traverse for class files. - * @param tmpDir - * The directory to place the class files in their package layout - * @param recurse - * Whether to unpack recursively - */ - private def sootLoad(input: File, tmpDir: File, recurse: Boolean): List[ClassFile] = - Options.v().set_soot_classpath(tmpDir.canonicalPath) - Options.v().set_prepend_classpath(true) - val classFiles = loadClassFiles(input, tmpDir, recurse) - val fullyQualifiedClassNames = classFiles.flatMap(_.fullyQualifiedClassName) - logger.debug(s"Loading ${classFiles.size} program files") - logger.debug(s"Source files are: ${classFiles.map(_.file.canonicalPath)}") - fullyQualifiedClassNames.foreach { fqcn => - Scene.v().addBasicClass(fqcn) - Scene.v().loadClassAndSupport(fqcn) - } - classFiles - - /** Apply the soot passes - * @param tmpDir - * A temporary directory that will be used as the classpath for extracted class files - */ - private def cpgApplyPasses(cpg: Cpg, config: Config, tmpDir: File): Unit = - val input = File(config.inputPath) - configureSoot(config, tmpDir) - new MetaDataPass(cpg, language, config.inputPath).createAndApply() - - val globalFromAstCreation: () => Global = input.extension match - case Some(".apk" | ".dex") if input.isRegularFile => - sootLoadApk(input, config.android) - { () => - val astCreator = SootAstCreationPass(cpg, config) - astCreator.createAndApply() - astCreator.global - } - case _ => - val classFiles = sootLoad(input, tmpDir, config.recurse) - { () => - val astCreator = AstCreationPass(classFiles, cpg, config) - astCreator.createAndApply() - astCreator.global - } - logger.debug("Loading classes to soot") - Scene.v().loadNecessaryClasses() - logger.debug(s"Loaded ${Scene.v().getApplicationClasses.size()} classes") - - val global = globalFromAstCreation() - TypeNodePass - .withRegisteredTypes(global.usedTypes.keys().asScala.toList, cpg) - .createAndApply() - DeclarationRefPass(cpg).createAndApply() - new ConfigFileCreationPass(cpg).createAndApply() - end cpgApplyPasses - - override def createCpg(config: Config): Try[Cpg] = - try - withNewEmptyCpg(config.outputPath, config: Config) { (cpg, config) => - File.temporaryDirectory("jimple2cpg-").apply { tmpDir => - cpgApplyPasses(cpg, config, tmpDir) - } - } - finally - G.reset() - - private def configureSoot(config: Config, outDir: File): Unit = - // set application mode - Options.v().set_app(false) - Options.v().set_whole_program(false) - // keep debugging info - Options.v().set_keep_line_number(true) - Options.v().set_keep_offset(true) - // ignore library code - Options.v().set_no_bodies_for_excluded(true) - Options.v().set_allow_phantom_refs(true) - // keep variable names - Options.v().setPhaseOption("jb.sils", "enabled:false") - Options.v().setPhaseOption("jb", "use-original-names:true") - // output jimple - Options.v().set_output_format(Options.output_format_jimple) - Options.v().set_output_dir(outDir.canonicalPath) - - Options.v().set_dynamic_dir(config.dynamicDirs.asJava) - Options.v().set_dynamic_package(config.dynamicPkgs.asJava) - - if config.fullResolver then - // full transitive resolution of all references - Options.v().set_full_resolver(true) - end configureSoot + import Jimple2Cpg.* + + private val logger = LoggerFactory.getLogger(classOf[Jimple2Cpg]) + + private def sootLoadApk(input: File, framework: Option[String] = None): Unit = + Options.v().set_process_dir(List(input.canonicalPath).asJava) + framework match + case Some(value) if value.nonEmpty => + Options.v().set_src_prec(Options.src_prec_apk) + Options.v().set_force_android_jar(value) + case _ => + Options.v().set_src_prec(Options.src_prec_apk_c_j) + Options.v().set_process_multiple_dex(true) + // workaround for Soot's bug while parsing large apk. + // see: https://github.com/soot-oss/soot/issues/1256 + Options.v().setPhaseOption("jb", "use-original-names:false") + + /** Load all class files from archives or directories recursively + * @param recurse + * Whether to unpack recursively + * @return + * The list of extracted class files whose package path could be extracted, placed on that + * package path relative to [[tmpDir]] + */ + private def loadClassFiles(src: File, tmpDir: File, recurse: Boolean): List[ClassFile] = + val archiveFileExtensions = Set(".jar", ".war", ".zip") + extractClassesInPackageLayout( + src, + tmpDir, + isClass = e => e.extension.contains(".class"), + isArchive = e => e.extension.exists(archiveFileExtensions.contains), + recurse + ) + + /** Extract all class files found, place them in their package layout and load them into soot. + * @param input + * The file/directory to traverse for class files. + * @param tmpDir + * The directory to place the class files in their package layout + * @param recurse + * Whether to unpack recursively + */ + private def sootLoad(input: File, tmpDir: File, recurse: Boolean): List[ClassFile] = + Options.v().set_soot_classpath(tmpDir.canonicalPath) + Options.v().set_prepend_classpath(true) + val classFiles = loadClassFiles(input, tmpDir, recurse) + val fullyQualifiedClassNames = classFiles.flatMap(_.fullyQualifiedClassName) + logger.debug(s"Loading ${classFiles.size} program files") + logger.debug(s"Source files are: ${classFiles.map(_.file.canonicalPath)}") + fullyQualifiedClassNames.foreach { fqcn => + Scene.v().addBasicClass(fqcn) + Scene.v().loadClassAndSupport(fqcn) + } + classFiles + + /** Apply the soot passes + * @param tmpDir + * A temporary directory that will be used as the classpath for extracted class files + */ + private def cpgApplyPasses(cpg: Cpg, config: Config, tmpDir: File): Unit = + val input = File(config.inputPath) + configureSoot(config, tmpDir) + new MetaDataPass(cpg, language, config.inputPath).createAndApply() + + val globalFromAstCreation: () => Global = input.extension match + case Some(".apk" | ".dex") if input.isRegularFile => + sootLoadApk(input, config.android) + { () => + val astCreator = SootAstCreationPass(cpg, config) + astCreator.createAndApply() + astCreator.global + } + case _ => + val classFiles = sootLoad(input, tmpDir, config.recurse) + { () => + val astCreator = AstCreationPass(classFiles, cpg, config) + astCreator.createAndApply() + astCreator.global + } + logger.debug("Loading classes to soot") + Scene.v().loadNecessaryClasses() + logger.debug(s"Loaded ${Scene.v().getApplicationClasses.size()} classes") + + val global = globalFromAstCreation() + TypeNodePass + .withRegisteredTypes(global.usedTypes.keys().asScala.toList, cpg) + .createAndApply() + DeclarationRefPass(cpg).createAndApply() + new ConfigFileCreationPass(cpg).createAndApply() + end cpgApplyPasses + + override def createCpg(config: Config): Try[Cpg] = + try + withNewEmptyCpg(config.outputPath, config: Config) { (cpg, config) => + File.temporaryDirectory("jimple2cpg-").apply { tmpDir => + cpgApplyPasses(cpg, config, tmpDir) + } + } + finally + G.reset() + + private def configureSoot(config: Config, outDir: File): Unit = + // set application mode + Options.v().set_app(false) + Options.v().set_whole_program(false) + // keep debugging info + Options.v().set_keep_line_number(true) + Options.v().set_keep_offset(true) + // ignore library code + Options.v().set_no_bodies_for_excluded(true) + Options.v().set_allow_phantom_refs(true) + // keep variable names + Options.v().setPhaseOption("jb.sils", "enabled:false") + Options.v().setPhaseOption("jb", "use-original-names:true") + // output jimple + Options.v().set_output_format(Options.output_format_jimple) + Options.v().set_output_dir(outDir.canonicalPath) + + Options.v().set_dynamic_dir(config.dynamicDirs.asJava) + Options.v().set_dynamic_package(config.dynamicPkgs.asJava) + + if config.fullResolver then + // full transitive resolution of all references + Options.v().set_full_resolver(true) + end configureSoot end Jimple2Cpg diff --git a/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/Main.scala b/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/Main.scala index 88bd9b2e..c3817e53 100644 --- a/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/Main.scala +++ b/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/Main.scala @@ -13,56 +13,56 @@ final case class Config( fullResolver: Boolean = false, recurse: Boolean = false ) extends X2CpgConfig[Config]: - def withAndroid(android: String): Config = - copy(android = Some(android)).withInheritedFields(this) - def withDynamicDirs(value: Seq[String]): Config = - copy(dynamicDirs = value).withInheritedFields(this) - def withDynamicPkgs(value: Seq[String]): Config = - copy(dynamicPkgs = value).withInheritedFields(this) - def withFullResolver(value: Boolean): Config = - copy(fullResolver = value).withInheritedFields(this) + def withAndroid(android: String): Config = + copy(android = Some(android)).withInheritedFields(this) + def withDynamicDirs(value: Seq[String]): Config = + copy(dynamicDirs = value).withInheritedFields(this) + def withDynamicPkgs(value: Seq[String]): Config = + copy(dynamicPkgs = value).withInheritedFields(this) + def withFullResolver(value: Boolean): Config = + copy(fullResolver = value).withInheritedFields(this) - def withRecurse(value: Boolean): Config = - copy(recurse = value) + def withRecurse(value: Boolean): Config = + copy(recurse = value) private object Frontend: - implicit val defaultConfig: Config = Config() + implicit val defaultConfig: Config = Config() - val cmdLineParser: OParser[Unit, Config] = - val builder = OParser.builder[Config] - import builder.* - OParser.sequence( - programName("jimple2cpg"), - opt[String]("android") - .text("Optional path to android.jar while processing apk file.") - .action((android, config) => config.withAndroid(android)), - opt[Unit]("full-resolver") - .text( - "enables full transitive resolution of all references found in all classes that are resolved" - ) - .action((_, config) => config.withFullResolver(true)), - opt[Unit]("recurse") - .text("recursively unpack jars") - .action((_, config) => config.withRecurse(true)), - opt[Seq[String]]("dynamic-dirs") - .valueName(",,...") - .text( - "Mark all class files in dirs as classes that may be loaded dynamically. Comma separated values for multiple directories." - ) - .action((dynamicDirs, config) => config.withDynamicDirs(dynamicDirs)), - opt[Seq[String]]("dynamic-pkgs") - .valueName(",,...") - .text( - "Marks all class files belonging to the package pkg or any of its subpackages as classes which the application may load dynamically. Comma separated values for multiple packages." - ) - .action((dynamicPkgs, config) => config.withDynamicPkgs(dynamicPkgs)) - ) - end cmdLineParser + val cmdLineParser: OParser[Unit, Config] = + val builder = OParser.builder[Config] + import builder.* + OParser.sequence( + programName("jimple2cpg"), + opt[String]("android") + .text("Optional path to android.jar while processing apk file.") + .action((android, config) => config.withAndroid(android)), + opt[Unit]("full-resolver") + .text( + "enables full transitive resolution of all references found in all classes that are resolved" + ) + .action((_, config) => config.withFullResolver(true)), + opt[Unit]("recurse") + .text("recursively unpack jars") + .action((_, config) => config.withRecurse(true)), + opt[Seq[String]]("dynamic-dirs") + .valueName(",,...") + .text( + "Mark all class files in dirs as classes that may be loaded dynamically. Comma separated values for multiple directories." + ) + .action((dynamicDirs, config) => config.withDynamicDirs(dynamicDirs)), + opt[Seq[String]]("dynamic-pkgs") + .valueName(",,...") + .text( + "Marks all class files belonging to the package pkg or any of its subpackages as classes which the application may load dynamically. Comma separated values for multiple packages." + ) + .action((dynamicPkgs, config) => config.withDynamicPkgs(dynamicPkgs)) + ) + end cmdLineParser end Frontend /** Entry point for command line CPG creator */ object Main extends X2CpgMain(cmdLineParser, new Jimple2Cpg()): - def run(config: Config, jimple2Cpg: Jimple2Cpg): Unit = - jimple2Cpg.run(config) + def run(config: Config, jimple2Cpg: Jimple2Cpg): Unit = + jimple2Cpg.run(config) diff --git a/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/AstCreationPass.scala b/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/AstCreationPass.scala index 5ac2f11f..8761a177 100644 --- a/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/AstCreationPass.scala +++ b/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/AstCreationPass.scala @@ -18,21 +18,21 @@ import soot.Scene class AstCreationPass(classFiles: List[ClassFile], cpg: Cpg, config: Config) extends ConcurrentWriterCpgPass[ClassFile](cpg): - val global: Global = new Global() - private val logger = LoggerFactory.getLogger(classOf[AstCreationPass]) + val global: Global = new Global() + private val logger = LoggerFactory.getLogger(classOf[AstCreationPass]) - override def generateParts(): Array[? <: AnyRef] = classFiles.toArray + override def generateParts(): Array[? <: AnyRef] = classFiles.toArray - override def runOnPart(builder: DiffGraphBuilder, classFile: ClassFile): Unit = - try - val sootClass = Scene.v().loadClassAndSupport(classFile.fullyQualifiedClassName.get) - sootClass.setApplicationClass() - val localDiff = AstCreator(classFile.file.canonicalPath, sootClass, global)( - config.schemaValidation - ).createAst() - builder.absorb(localDiff) - catch - case e: Exception => - logger.warn(s"Exception on AST creation for ${classFile.file.canonicalPath}", e) - Iterator() + override def runOnPart(builder: DiffGraphBuilder, classFile: ClassFile): Unit = + try + val sootClass = Scene.v().loadClassAndSupport(classFile.fullyQualifiedClassName.get) + sootClass.setApplicationClass() + val localDiff = AstCreator(classFile.file.canonicalPath, sootClass, global)( + config.schemaValidation + ).createAst() + builder.absorb(localDiff) + catch + case e: Exception => + logger.warn(s"Exception on AST creation for ${classFile.file.canonicalPath}", e) + Iterator() end AstCreationPass diff --git a/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/AstCreator.scala b/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/AstCreator.scala index f18cf4cb..f3260b41 100644 --- a/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/AstCreator.scala +++ b/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/AstCreator.scala @@ -24,1161 +24,1159 @@ class AstCreator(filename: String, cls: SootClass, global: Global)(implicit withSchemaValidation: ValidationMode ) extends AstCreatorBase(filename): - import AstCreator.* - - private val logger = LoggerFactory.getLogger(classOf[AstCreationPass]) - private val unitToAsts = mutable.HashMap[soot.Unit, Seq[Ast]]() - private val controlTargets = mutable.HashMap[Seq[Ast], soot.Unit]() - // There are many, but the popular ones should do https://en.wikipedia.org/wiki/List_of_JVM_languages - private val JVM_LANGS = HashSet("scala", "clojure", "groovy", "kotlin", "jython", "jruby") - - /** Add `typeName` to a global map and return it. The map is later passed to a pass that creates - * TYPE nodes for each key in the map. - */ - private def registerType(typeName: String): String = - global.usedTypes.put(typeName, true) - typeName - - /** Entry point of AST creation. Translates a compilation unit created by JavaParser into a - * DiffGraph containing the corresponding CPG AST. - */ - def createAst(): DiffGraphBuilder = - val astRoot = astForCompilationUnit(cls) - storeInDiffGraph(astRoot, diffGraph) - diffGraph - - /** Translate compilation unit into AST - */ - private def astForCompilationUnit(cls: SootClass): Ast = - val ast = astForPackageDeclaration(cls.getPackageName) - val namespaceBlockFullName = - ast.root.collect { case x: NewNamespaceBlock => x.fullName }.getOrElse("none") - ast.withChild(astForTypeDecl(cls.getType, namespaceBlockFullName)) - - /** Translate package declaration into AST consisting of a corresponding namespace block. - */ - private def astForPackageDeclaration(packageDecl: String): Ast = - val absolutePath = new java.io.File(filename).toPath.toAbsolutePath.normalize().toString - val name = packageDecl.split("\\.").lastOption.getOrElse("") - val namespaceBlock = NewNamespaceBlock() - .name(name) - .fullName(packageDecl) - Ast(namespaceBlock.filename(absolutePath).order(1)) - - /** Creates a list of all inherited classes and implemented interfaces. If there are none then a - * list with a single element 'java.lang.Object' is returned by default. Returns two lists in - * the form of (List[Super Classes], List[Interfaces]). - */ - private def inheritedAndImplementedClasses(clazz: SootClass): (List[String], List[String]) = - val implementsTypeFullName = clazz.getInterfaces.asScala.map { (i: SootClass) => - registerType(i.getType.toQuotedString) - }.toList - val inheritsFromTypeFullName = - if clazz.hasSuperclass && clazz.getSuperclass.getType.toQuotedString != "java.lang.Object" - then - List(registerType(clazz.getSuperclass.getType.toQuotedString)) - else if implementsTypeFullName.isEmpty then - List(registerType("java.lang.Object")) - else List() - - (inheritsFromTypeFullName, implementsTypeFullName) - - /** Creates the AST root for type declarations and acts as the entry point for method - * generation. - */ - private def astForTypeDecl(typ: RefType, namespaceBlockFullName: String): Ast = - val fullName = registerType(typ.toQuotedString) - val shortName = typ.getSootClass.getShortJavaStyleName - val clz = typ.getSootClass - val code = new mutable.StringBuilder() - - if clz.isPublic then code.append("public ") - else if clz.isPrivate then code.append("private ") - if clz.isStatic then code.append("static ") - if clz.isFinal then code.append("final ") - if clz.isInterface then code.append("interface ") - else if clz.isAbstract then code.append("abstract ") - if clz.isEnum then code.append("enum ") - if !clz.isInterface then code.append(s"class $shortName") - else code.append(shortName) - - val modifiers = astsForModifiers(clz) - val (inherited, implemented) = inheritedAndImplementedClasses(typ.getSootClass) - - if inherited.nonEmpty then code.append(s" extends ${inherited.mkString(", ")}") - if implemented.nonEmpty then code.append(s" implements ${implemented.mkString(", ")}") - - val typeDecl = NewTypeDecl() - .name(shortName) - .fullName(fullName) - .order(1) // Jimple always has 1 class per file - .filename(filename) - .code(code.toString()) - .inheritsFromTypeFullName(inherited ++ implemented) - .astParentType(NodeTypes.NAMESPACE_BLOCK) - .astParentFullName(namespaceBlockFullName) - val methodAsts = withOrder(typ.getSootClass.getMethods.asScala.toList.sortWith((x, y) => - x.getName > y.getName - )) { - (m, order) => - astForMethod(m, typ, order) + import AstCreator.* + + private val logger = LoggerFactory.getLogger(classOf[AstCreationPass]) + private val unitToAsts = mutable.HashMap[soot.Unit, Seq[Ast]]() + private val controlTargets = mutable.HashMap[Seq[Ast], soot.Unit]() + // There are many, but the popular ones should do https://en.wikipedia.org/wiki/List_of_JVM_languages + private val JVM_LANGS = HashSet("scala", "clojure", "groovy", "kotlin", "jython", "jruby") + + /** Add `typeName` to a global map and return it. The map is later passed to a pass that creates + * TYPE nodes for each key in the map. + */ + private def registerType(typeName: String): String = + global.usedTypes.put(typeName, true) + typeName + + /** Entry point of AST creation. Translates a compilation unit created by JavaParser into a + * DiffGraph containing the corresponding CPG AST. + */ + def createAst(): DiffGraphBuilder = + val astRoot = astForCompilationUnit(cls) + storeInDiffGraph(astRoot, diffGraph) + diffGraph + + /** Translate compilation unit into AST + */ + private def astForCompilationUnit(cls: SootClass): Ast = + val ast = astForPackageDeclaration(cls.getPackageName) + val namespaceBlockFullName = + ast.root.collect { case x: NewNamespaceBlock => x.fullName }.getOrElse("none") + ast.withChild(astForTypeDecl(cls.getType, namespaceBlockFullName)) + + /** Translate package declaration into AST consisting of a corresponding namespace block. + */ + private def astForPackageDeclaration(packageDecl: String): Ast = + val absolutePath = new java.io.File(filename).toPath.toAbsolutePath.normalize().toString + val name = packageDecl.split("\\.").lastOption.getOrElse("") + val namespaceBlock = NewNamespaceBlock() + .name(name) + .fullName(packageDecl) + Ast(namespaceBlock.filename(absolutePath).order(1)) + + /** Creates a list of all inherited classes and implemented interfaces. If there are none then a + * list with a single element 'java.lang.Object' is returned by default. Returns two lists in the + * form of (List[Super Classes], List[Interfaces]). + */ + private def inheritedAndImplementedClasses(clazz: SootClass): (List[String], List[String]) = + val implementsTypeFullName = clazz.getInterfaces.asScala.map { (i: SootClass) => + registerType(i.getType.toQuotedString) + }.toList + val inheritsFromTypeFullName = + if clazz.hasSuperclass && clazz.getSuperclass.getType.toQuotedString != "java.lang.Object" + then + List(registerType(clazz.getSuperclass.getType.toQuotedString)) + else if implementsTypeFullName.isEmpty then + List(registerType("java.lang.Object")) + else List() + + (inheritsFromTypeFullName, implementsTypeFullName) + + /** Creates the AST root for type declarations and acts as the entry point for method generation. + */ + private def astForTypeDecl(typ: RefType, namespaceBlockFullName: String): Ast = + val fullName = registerType(typ.toQuotedString) + val shortName = typ.getSootClass.getShortJavaStyleName + val clz = typ.getSootClass + val code = new mutable.StringBuilder() + + if clz.isPublic then code.append("public ") + else if clz.isPrivate then code.append("private ") + if clz.isStatic then code.append("static ") + if clz.isFinal then code.append("final ") + if clz.isInterface then code.append("interface ") + else if clz.isAbstract then code.append("abstract ") + if clz.isEnum then code.append("enum ") + if !clz.isInterface then code.append(s"class $shortName") + else code.append(shortName) + + val modifiers = astsForModifiers(clz) + val (inherited, implemented) = inheritedAndImplementedClasses(typ.getSootClass) + + if inherited.nonEmpty then code.append(s" extends ${inherited.mkString(", ")}") + if implemented.nonEmpty then code.append(s" implements ${implemented.mkString(", ")}") + + val typeDecl = NewTypeDecl() + .name(shortName) + .fullName(fullName) + .order(1) // Jimple always has 1 class per file + .filename(filename) + .code(code.toString()) + .inheritsFromTypeFullName(inherited ++ implemented) + .astParentType(NodeTypes.NAMESPACE_BLOCK) + .astParentFullName(namespaceBlockFullName) + val methodAsts = withOrder(typ.getSootClass.getMethods.asScala.toList.sortWith((x, y) => + x.getName > y.getName + )) { + (m, order) => + astForMethod(m, typ, order) + } + + val memberAsts = typ.getSootClass.getFields.asScala + .filter(_.isDeclared) + .zipWithIndex + .map { case (v, i) => + astForField(v, i + methodAsts.size + 1) } - - val memberAsts = typ.getSootClass.getFields.asScala - .filter(_.isDeclared) - .zipWithIndex - .map { case (v, i) => - astForField(v, i + methodAsts.size + 1) + .toList + + Ast(typeDecl) + .withChildren(astsForHostTags(clz)) + .withChildren(memberAsts) + .withChildren(methodAsts) + .withChildren(modifiers) + end astForTypeDecl + + private def astForField(field: SootField, order: Int): Ast = + val typeFullName = registerType(field.getType.toQuotedString) + val name = field.getName + val code = if field.getDeclaration.contains("enum") then name else s"$typeFullName $name" + val annotations = field.getTags.asScala + .collect { case x: VisibilityAnnotationTag => x } + .flatMap(_.getAnnotations.asScala) + + Ast( + NewMember() + .name(name) + .lineNumber(line(field)) + .columnNumber(column(field)) + .typeFullName(typeFullName) + .order(order) + .code(code) + ).withChildren(withOrder(annotations) { (a, aOrder) => + astsForAnnotations(a, aOrder, field) + }) + end astForField + + private def astForMethod(methodDeclaration: SootMethod, typeDecl: RefType, childNum: Int): Ast = + val methodNode = createMethodNode(methodDeclaration, typeDecl, childNum) + try + if !methodDeclaration.isConcrete then + // Soot is not able to parse origin parameter names of abstract methods + // https://github.com/soot-oss/soot/issues/1517 + val locals = methodDeclaration.getParameterTypes.asScala.zipWithIndex + .map { case (typ, index) => new JimpleLocal(s"param${index + 1}", typ) } + val parameterAsts = + Seq(createThisNode(methodDeclaration, NewMethodParameterIn())) ++ + withOrder(locals) { (p, order) => + astForParameter(p, order, methodDeclaration, Map()) + } + Ast(methodNode) + .withChildren(astsForModifiers(methodDeclaration)) + .withChildren(parameterAsts) + .withChildren(astsForHostTags(methodDeclaration)) + .withChild(Ast(NewBlock())) + .withChild(astForMethodReturn(methodDeclaration)) + else + val lastOrder = 2 + methodDeclaration.getParameterCount + // Map params to their annotations + val mTags = methodDeclaration.getTags.asScala + val paramAnnos = + mTags.collect { case x: VisibilityParameterAnnotationTag => x }.flatMap( + _.getVisibilityAnnotations.asScala + ) + val paramNames = + mTags.collect { case x: ParamNamesTag => x }.flatMap(_.getNames.asScala) + val parameterAnnotations = paramNames.zip(paramAnnos).filter(_._2 != null).toMap + val methodBody = Try(methodDeclaration.getActiveBody) match + case Failure(_) => methodDeclaration.retrieveActiveBody() + case Success(body) => body + val parameterAsts = + Seq(createThisNode(methodDeclaration, NewMethodParameterIn())) ++ withOrder( + methodBody.getParameterLocals + ) { + (p, order) => + astForParameter(p, order, methodDeclaration, parameterAnnotations) } - .toList - - Ast(typeDecl) - .withChildren(astsForHostTags(clz)) - .withChildren(memberAsts) - .withChildren(methodAsts) - .withChildren(modifiers) - end astForTypeDecl - - private def astForField(field: SootField, order: Int): Ast = - val typeFullName = registerType(field.getType.toQuotedString) - val name = field.getName - val code = if field.getDeclaration.contains("enum") then name else s"$typeFullName $name" - val annotations = field.getTags.asScala - .collect { case x: VisibilityAnnotationTag => x } - .flatMap(_.getAnnotations.asScala) - Ast( - NewMember() - .name(name) - .lineNumber(line(field)) - .columnNumber(column(field)) - .typeFullName(typeFullName) - .order(order) - .code(code) - ).withChildren(withOrder(annotations) { (a, aOrder) => - astsForAnnotations(a, aOrder, field) - }) - end astForField - - private def astForMethod(methodDeclaration: SootMethod, typeDecl: RefType, childNum: Int): Ast = - val methodNode = createMethodNode(methodDeclaration, typeDecl, childNum) - try - if !methodDeclaration.isConcrete then - // Soot is not able to parse origin parameter names of abstract methods - // https://github.com/soot-oss/soot/issues/1517 - val locals = methodDeclaration.getParameterTypes.asScala.zipWithIndex - .map { case (typ, index) => new JimpleLocal(s"param${index + 1}", typ) } - val parameterAsts = - Seq(createThisNode(methodDeclaration, NewMethodParameterIn())) ++ - withOrder(locals) { (p, order) => - astForParameter(p, order, methodDeclaration, Map()) - } - Ast(methodNode) - .withChildren(astsForModifiers(methodDeclaration)) - .withChildren(parameterAsts) - .withChildren(astsForHostTags(methodDeclaration)) - .withChild(Ast(NewBlock())) - .withChild(astForMethodReturn(methodDeclaration)) - else - val lastOrder = 2 + methodDeclaration.getParameterCount - // Map params to their annotations - val mTags = methodDeclaration.getTags.asScala - val paramAnnos = - mTags.collect { case x: VisibilityParameterAnnotationTag => x }.flatMap( - _.getVisibilityAnnotations.asScala - ) - val paramNames = - mTags.collect { case x: ParamNamesTag => x }.flatMap(_.getNames.asScala) - val parameterAnnotations = paramNames.zip(paramAnnos).filter(_._2 != null).toMap - val methodBody = Try(methodDeclaration.getActiveBody) match - case Failure(_) => methodDeclaration.retrieveActiveBody() - case Success(body) => body - val parameterAsts = - Seq(createThisNode(methodDeclaration, NewMethodParameterIn())) ++ withOrder( - methodBody.getParameterLocals - ) { - (p, order) => - astForParameter(p, order, methodDeclaration, parameterAnnotations) - } - Ast( - methodNode - .lineNumberEnd(methodBody.toString.split('\n').filterNot(_.isBlank).length) - .code(methodBody.toString) - ) - .withChildren(astsForModifiers(methodDeclaration)) - .withChildren(parameterAsts) - .withChildren(astsForHostTags(methodDeclaration)) - .withChild(astForMethodBody(methodBody, lastOrder)) - .withChild(astForMethodReturn(methodDeclaration)) - catch - case e: RuntimeException => - // Use a few heuristics to determine if this is not built with the JDK - val nonJavaLibs = - cls.getInterfaces.asScala.map(_.getPackageName).filter(JVM_LANGS.contains).toSet - if nonJavaLibs.nonEmpty || cls.getMethods.asScala.exists(_.getName.endsWith("$")) - then - val errMsg = - "The bytecode for this method suggests it is built with a non-Java JVM language. " + - "Soot requires including the specific language's SDK in the analysis to create the method body for " + - s"'${methodNode.fullName}' correctly." - logger.warn( - if nonJavaLibs.nonEmpty then - s"$errMsg. Language(s) detected: ${nonJavaLibs.mkString(",")}." - else errMsg - ) - else - logger.warn( - s"Unexpected runtime exception while parsing method body! Will stub the method '${methodNode.fullName}''", - e - ) - Ast(methodNode) - .withChildren(astsForModifiers(methodDeclaration)) - .withChildren(astsForHostTags(methodDeclaration)) - .withChild(astForMethodReturn(methodDeclaration)) - finally - // Join all targets with CFG edges - this seems to work from what is seen on DotFiles - controlTargets.foreach({ case (asts, units) => - asts.headOption match - case Some(value) => - diffGraph.addEdge( - value.root.get, - unitToAsts(units).last.root.get, - EdgeTypes.CFG - ) - case None => - }) - // Clear these maps - controlTargets.clear() - unitToAsts.clear() - end try - end astForMethod - - private def getEvaluationStrategy(typ: soot.Type): String = - typ match - case _: PrimType => EvaluationStrategies.BY_VALUE - case _: VoidType => EvaluationStrategies.BY_VALUE - case _: NullType => EvaluationStrategies.BY_VALUE - case _: RefLikeType => EvaluationStrategies.BY_REFERENCE - case _ => EvaluationStrategies.BY_SHARING - - private def astForParameter( - parameter: soot.Local, - childNum: Int, - methodDeclaration: SootMethod, - parameterAnnotations: Map[String, VisibilityAnnotationTag] - ): Ast = - val typeFullName = registerType(parameter.getType.toQuotedString) - - val parameterNode = Ast( - NewMethodParameterIn() - .name(parameter.getName) - .code(s"$typeFullName ${parameter.getName}") - .typeFullName(typeFullName) - .order(childNum) - .index(childNum) - .lineNumber(line(methodDeclaration)) - .columnNumber(column(methodDeclaration)) - .evaluationStrategy(getEvaluationStrategy(parameter.getType)) + methodNode + .lineNumberEnd(methodBody.toString.split('\n').filterNot(_.isBlank).length) + .code(methodBody.toString) ) - - parameterAnnotations.get(parameter.getName) match - case Some(annoRoot) => - parameterNode.withChildren(withOrder(annoRoot.getAnnotations.asScala) { - (a, order) => - astsForAnnotations(a, order, methodDeclaration) - }) - case None => parameterNode - end astForParameter - - private def astsForHostTags(host: AbstractHost): Seq[Ast] = - host.getTags.asScala - .collect { case x: VisibilityAnnotationTag => x } - .flatMap { x => - withOrder(x.getAnnotations.asScala) { (a, order) => - astsForAnnotations(a, order, host) - } - } - .toSeq - - private def astsForAnnotations(annotation: AnnotationTag, order: Int, host: AbstractHost): Ast = - val annoType = registerType(annotation.getType.parseAsJavaType) - val name = annoType.split('.').last - val elementNodes = withOrder(annotation.getElems.asScala) { case (a, order) => - astForAnnotationElement(a, order, host) - } - val annotationNode = NewAnnotation() - .name(name) - .code( - s"@$name(${elementNodes.flatMap(_.root).flatMap(_.properties.get(PropertyNames.CODE)).mkString(", ")})" + .withChildren(astsForModifiers(methodDeclaration)) + .withChildren(parameterAsts) + .withChildren(astsForHostTags(methodDeclaration)) + .withChild(astForMethodBody(methodBody, lastOrder)) + .withChild(astForMethodReturn(methodDeclaration)) + catch + case e: RuntimeException => + // Use a few heuristics to determine if this is not built with the JDK + val nonJavaLibs = + cls.getInterfaces.asScala.map(_.getPackageName).filter(JVM_LANGS.contains).toSet + if nonJavaLibs.nonEmpty || cls.getMethods.asScala.exists(_.getName.endsWith("$")) + then + val errMsg = + "The bytecode for this method suggests it is built with a non-Java JVM language. " + + "Soot requires including the specific language's SDK in the analysis to create the method body for " + + s"'${methodNode.fullName}' correctly." + logger.warn( + if nonJavaLibs.nonEmpty then + s"$errMsg. Language(s) detected: ${nonJavaLibs.mkString(",")}." + else errMsg ) - .fullName(annoType) - .order(order) - Ast(annotationNode) - .withChildren(elementNodes) - - private def astForAnnotationElement( - annoElement: AnnotationElem, - order: Int, - parent: AbstractHost - ): Ast = - def getLiteralElementNameAndCode(annoElement: AnnotationElem): (String, String) = - annoElement match - case x: AnnotationClassElem => - val desc = registerType(x.getDesc.parseAsJavaType) - (desc, desc) - case x: AnnotationBooleanElem => (x.getValue.toString, x.getValue.toString) - case x: AnnotationDoubleElem => (x.getValue.toString, x.getValue.toString) - case x: AnnotationEnumElem => (x.getConstantName, x.getConstantName) - case x: AnnotationFloatElem => (x.getValue.toString, x.getValue.toString) - case x: AnnotationIntElem => (x.getValue.toString, x.getValue.toString) - case x: AnnotationLongElem => (x.getValue.toString, x.getValue.toString) - case _ => ("", "") - val lineNo = line(parent) - val columnNo = column(parent) - val codeBuilder = new mutable.StringBuilder() - val astChildren = ListBuffer.empty[Ast] - if annoElement.getName != null then - astChildren.append( - Ast( - NewAnnotationParameter() - .code(annoElement.getName) - .lineNumber(lineNo) - .columnNumber(columnNo) - .order(1) - ) + else + logger.warn( + s"Unexpected runtime exception while parsing method body! Will stub the method '${methodNode.fullName}''", + e ) - codeBuilder.append(s"${annoElement.getName} = ") - astChildren.append(annoElement match - case x: AnnotationAnnotationElem => - val rhsAst = astsForAnnotations(x.getValue, astChildren.size + 1, parent) - codeBuilder.append( - s"${rhsAst.root.flatMap(_.properties.get(PropertyNames.CODE)).mkString(", ")}" - ) - rhsAst - case x: AnnotationArrayElem => - val (rhsAst, code) = astForAnnotationArrayElement(x, astChildren.size + 1, parent) - codeBuilder.append(code) - rhsAst - case x => - val (name, code) = x match - case y: AnnotationStringElem => (y.getValue, s"\"${y.getValue}\"") - case _ => getLiteralElementNameAndCode(x) - val rhsOrder = if annoElement.getName == null then order else astChildren.size + 1 - codeBuilder.append(code) - Ast(NewAnnotationLiteral().name(name).code(code).order(rhsOrder).argumentIndex( - rhsOrder - )) - ) - - if astChildren.size == 1 then - astChildren.head - else - val paramAssign = NewAnnotationParameterAssign() - .code(codeBuilder.toString) - .lineNumber(lineNo) - .columnNumber(columnNo) - .order(order) - - Ast(paramAssign) - .withChildren(astChildren) - end astForAnnotationElement - - private def astForAnnotationArrayElement( - x: AnnotationArrayElem, - order: Int, - parent: AbstractHost - ): (Ast, String) = - val elems = withOrder(x.getValues.asScala) { (elem, order) => - astForAnnotationElement(elem, order, parent) - } - val code = - s"{${elems.flatMap(_.root).flatMap(_.properties.get(PropertyNames.CODE)).mkString(", ")}}" - val array = NewArrayInitializer().code(code).order(order).argumentIndex(order) - (Ast(array).withChildren(elems), code) - - private def astForMethodBody(body: Body, order: Int): Ast = - val block = NewBlock().order(order).lineNumber(line(body)).columnNumber(column(body)) - val jimpleParams = body.getParameterLocals.asScala.toList - // Don't let parameters also become locals (avoiding duplication) - val jimpleLocals = body.getLocals.asScala.filterNot(l => - jimpleParams.contains(l) || l.getName == "this" - ).toList - val locals = withOrder(jimpleLocals) { case (l, order) => - val name = l.getName - val typeFullName = registerType(l.getType.toQuotedString) - val code = s"$typeFullName $name" - Ast(NewLocal().name(name).code(code).typeFullName(typeFullName).order(order)) - } - Ast(block) - .withChildren(locals) - .withChildren(withOrder(body.getUnits.asScala.filterNot(isIgnoredUnit)) { (x, order) => - astsForStatement(x, order + locals.size) - }.flatten) - - private def isIgnoredUnit(unit: soot.Unit): Boolean = - unit match - case _: IdentityStmt => true - case _: NopStmt => true - case _ => false - - private def astsForStatement(statement: soot.Unit, order: Int): Seq[Ast] = - val stmt = statement match - case x: AssignStmt => astsForDefinition(x, order) - case x: InvokeStmt => astsForExpression(x.getInvokeExpr, order, statement) - case x: ReturnStmt => astsForReturnNode(x, order) - case x: ReturnVoidStmt => astsForReturnVoidNode(x, order) - case x: IfStmt => astsForIfStmt(x, order) - case x: GotoStmt => astsForGotoStmt(x, order) - case x: LookupSwitchStmt => astsForLookupSwitchStmt(x, order) - case x: TableSwitchStmt => astsForTableSwitchStmt(x, order) - case x: ThrowStmt => astsForThrowStmt(x, order) - case x: MonitorStmt => astsForMonitorStmt(x, order) - case _: IdentityStmt => Seq() // Identity statements redefine parameters as locals - case _: NopStmt => Seq() // Ignore NOP statements - case x => - logger.warn(s"Unhandled soot.Unit type ${x.getClass}") - Seq(astForUnknownStmt(x, None, order)) - unitToAsts.put(statement, stmt) - stmt - end astsForStatement - - private def astForBinOpExpr(binOp: BinopExpr, order: Int, parentUnit: soot.Unit): Ast = - // https://javadoc.io/static/org.soot-oss/soot/4.3.0/soot/jimple/BinopExpr.html - val operatorName = binOp match - case _: AddExpr => Operators.addition - case _: SubExpr => Operators.subtraction - case _: MulExpr => Operators.multiplication - case _: DivExpr => Operators.division - case _: RemExpr => Operators.modulo - case _: GeExpr => Operators.greaterEqualsThan - case _: GtExpr => Operators.greaterThan - case _: LeExpr => Operators.lessEqualsThan - case _: LtExpr => Operators.lessThan - case _: ShlExpr => Operators.shiftLeft - case _: ShrExpr => Operators.logicalShiftRight - case _: UshrExpr => Operators.arithmeticShiftRight - case _: CmpExpr => Operators.compare - case _: CmpgExpr => Operators.compare - case _: CmplExpr => Operators.compare - case _: AndExpr => Operators.and - case _: OrExpr => Operators.or - case _: XorExpr => Operators.xor - case _: EqExpr => Operators.equals - case _: NeExpr => Operators.notEquals - case _ => - logger.warn( - s"Unhandled binary operator ${binOp.getSymbol} (${binOp.getClass}). This is unexpected behaviour." + Ast(methodNode) + .withChildren(astsForModifiers(methodDeclaration)) + .withChildren(astsForHostTags(methodDeclaration)) + .withChild(astForMethodReturn(methodDeclaration)) + finally + // Join all targets with CFG edges - this seems to work from what is seen on DotFiles + controlTargets.foreach({ case (asts, units) => + asts.headOption match + case Some(value) => + diffGraph.addEdge( + value.root.get, + unitToAsts(units).last.root.get, + EdgeTypes.CFG ) - ".unknown" - - val callNode = NewCall() - .name(operatorName) - .methodFullName(operatorName) - .dispatchType(DispatchTypes.STATIC_DISPATCH) - .code(binOp.toString) - .argumentIndex(order) - .order(order) - - val args = - astsForValue(binOp.getOp1, 1, parentUnit) ++ astsForValue(binOp.getOp2, 2, parentUnit) - callAst(callNode, args) - end astForBinOpExpr - - private def astsForExpression(expr: Expr, order: Int, parentUnit: soot.Unit): Seq[Ast] = - expr match - case x: BinopExpr => Seq(astForBinOpExpr(x, order, parentUnit)) - case x: InvokeExpr => Seq(astForInvokeExpr(x, order, parentUnit)) - case x: AnyNewExpr => Seq(astForNewExpr(x, order, parentUnit)) - case x: CastExpr => Seq(astForUnaryExpr(Operators.cast, x, x.getOp, order, parentUnit)) - case x: InstanceOfExpr => - Seq(astForUnaryExpr(Operators.instanceOf, x, x.getOp, order, parentUnit)) - case x: LengthExpr => - Seq(astForUnaryExpr(Operators.lengthOf, x, x.getOp, order, parentUnit)) - case x: NegExpr => Seq(astForUnaryExpr(Operators.minus, x, x.getOp, order, parentUnit)) - case x => - logger.warn(s"Unhandled soot.Expr type ${x.getClass}") - Seq() - - private def astsForValue(value: soot.Value, order: Int, parentUnit: soot.Unit): Seq[Ast] = - value match - case x: Expr => astsForExpression(x, order, parentUnit) - case x: soot.Local => Seq(astForLocal(x, order, parentUnit)) - case x: CaughtExceptionRef => Seq(astForCaughtExceptionRef(x, order, parentUnit)) - case x: Constant => Seq(astForConstantExpr(x, order)) - case x: FieldRef => Seq(astForFieldRef(x, order, parentUnit)) - case x: ThisRef => Seq(createThisNode(x)) - case x: ParameterRef => Seq(createParameterNode(x, order)) - case x: IdentityRef => Seq(astForIdentityRef(x, order, parentUnit)) - case x: ArrayRef => Seq(astForArrayRef(x, order, parentUnit)) - case x => - logger.warn(s"Unhandled soot.Value type ${x.getClass}") - Seq() - - private def astForArrayRef(arrRef: ArrayRef, order: Int, parentUnit: soot.Unit): Ast = - val indexAccess = NewCall() - .name(Operators.indexAccess) - .methodFullName(Operators.indexAccess) - .dispatchType(DispatchTypes.STATIC_DISPATCH) - .code(arrRef.toString()) - .order(order) - .argumentIndex(order) - .lineNumber(line(parentUnit)) - .columnNumber(column(parentUnit)) - .typeFullName(registerType(arrRef.getType.toQuotedString)) - - val astChildren = astsForValue(arrRef.getBase, 1, parentUnit) ++ astsForValue( - arrRef.getIndex, - 2, - parentUnit - ) - Ast(indexAccess) - .withChildren(astChildren) - .withArgEdges(indexAccess, astChildren.flatMap(_.root)) - end astForArrayRef - - private def astForLocal(local: soot.Local, order: Int, parentUnit: soot.Unit): Ast = - val name = local.getName - val typeFullName = registerType(local.getType.toQuotedString) - Ast( - NewIdentifier() - .name(name) - .lineNumber(line(parentUnit)) - .columnNumber(column(parentUnit)) - .order(order) - .argumentIndex(order) - .code(name) - .typeFullName(typeFullName) + case None => + }) + // Clear these maps + controlTargets.clear() + unitToAsts.clear() + end try + end astForMethod + + private def getEvaluationStrategy(typ: soot.Type): String = + typ match + case _: PrimType => EvaluationStrategies.BY_VALUE + case _: VoidType => EvaluationStrategies.BY_VALUE + case _: NullType => EvaluationStrategies.BY_VALUE + case _: RefLikeType => EvaluationStrategies.BY_REFERENCE + case _ => EvaluationStrategies.BY_SHARING + + private def astForParameter( + parameter: soot.Local, + childNum: Int, + methodDeclaration: SootMethod, + parameterAnnotations: Map[String, VisibilityAnnotationTag] + ): Ast = + val typeFullName = registerType(parameter.getType.toQuotedString) + + val parameterNode = Ast( + NewMethodParameterIn() + .name(parameter.getName) + .code(s"$typeFullName ${parameter.getName}") + .typeFullName(typeFullName) + .order(childNum) + .index(childNum) + .lineNumber(line(methodDeclaration)) + .columnNumber(column(methodDeclaration)) + .evaluationStrategy(getEvaluationStrategy(parameter.getType)) + ) + + parameterAnnotations.get(parameter.getName) match + case Some(annoRoot) => + parameterNode.withChildren(withOrder(annoRoot.getAnnotations.asScala) { + (a, order) => + astsForAnnotations(a, order, methodDeclaration) + }) + case None => parameterNode + end astForParameter + + private def astsForHostTags(host: AbstractHost): Seq[Ast] = + host.getTags.asScala + .collect { case x: VisibilityAnnotationTag => x } + .flatMap { x => + withOrder(x.getAnnotations.asScala) { (a, order) => + astsForAnnotations(a, order, host) + } + } + .toSeq + + private def astsForAnnotations(annotation: AnnotationTag, order: Int, host: AbstractHost): Ast = + val annoType = registerType(annotation.getType.parseAsJavaType) + val name = annoType.split('.').last + val elementNodes = withOrder(annotation.getElems.asScala) { case (a, order) => + astForAnnotationElement(a, order, host) + } + val annotationNode = NewAnnotation() + .name(name) + .code( + s"@$name(${elementNodes.flatMap(_.root).flatMap(_.properties.get(PropertyNames.CODE)).mkString(", ")})" ) - - private def astForIdentityRef(x: IdentityRef, order: Int, parentUnit: soot.Unit): Ast = + .fullName(annoType) + .order(order) + Ast(annotationNode) + .withChildren(elementNodes) + + private def astForAnnotationElement( + annoElement: AnnotationElem, + order: Int, + parent: AbstractHost + ): Ast = + def getLiteralElementNameAndCode(annoElement: AnnotationElem): (String, String) = + annoElement match + case x: AnnotationClassElem => + val desc = registerType(x.getDesc.parseAsJavaType) + (desc, desc) + case x: AnnotationBooleanElem => (x.getValue.toString, x.getValue.toString) + case x: AnnotationDoubleElem => (x.getValue.toString, x.getValue.toString) + case x: AnnotationEnumElem => (x.getConstantName, x.getConstantName) + case x: AnnotationFloatElem => (x.getValue.toString, x.getValue.toString) + case x: AnnotationIntElem => (x.getValue.toString, x.getValue.toString) + case x: AnnotationLongElem => (x.getValue.toString, x.getValue.toString) + case _ => ("", "") + val lineNo = line(parent) + val columnNo = column(parent) + val codeBuilder = new mutable.StringBuilder() + val astChildren = ListBuffer.empty[Ast] + if annoElement.getName != null then + astChildren.append( Ast( - NewIdentifier() - .code(x.toString()) - .name(x.toString()) - .order(order) - .argumentIndex(order) - .typeFullName(registerType(x.getType.toQuotedString)) - .lineNumber(line(parentUnit)) - .columnNumber(column(parentUnit)) + NewAnnotationParameter() + .code(annoElement.getName) + .lineNumber(lineNo) + .columnNumber(columnNo) + .order(1) ) - - private def astForInvokeExpr(invokeExpr: InvokeExpr, order: Int, parentUnit: soot.Unit): Ast = - val callee = invokeExpr.getMethodRef - val dispatchType = invokeExpr match - case _ if callee.isConstructor => DispatchTypes.STATIC_DISPATCH - case _: DynamicInvokeExpr => DispatchTypes.DYNAMIC_DISPATCH - case _: InstanceInvokeExpr => DispatchTypes.DYNAMIC_DISPATCH - case _ => DispatchTypes.STATIC_DISPATCH - - val signature = - s"${registerType(callee.getReturnType.toQuotedString)}(${(for (i <- 0 until callee.getParameterTypes.size()) - yield registerType(callee.getParameterType(i).toQuotedString)).mkString(",")})" - val thisAsts = invokeExpr match - case expr: InstanceInvokeExpr => astsForValue(expr.getBase, 0, parentUnit) - case _ => Seq(createThisNode(callee, NewIdentifier())) - - val methodName = - if callee.isConstructor then - registerType(callee.getDeclaringClass.getType.getClassName) - else - callee.getName - - val calleeType = registerType(callee.getDeclaringClass.getType.toQuotedString) - val callType = - if callee.isConstructor then "void" - else calleeType - - val code = invokeExpr match - case expr: InstanceInvokeExpr => - s"${expr.getBase}.$methodName(${invokeExpr.getArgs.asScala.mkString(", ")})" - case _ => s"$methodName(${invokeExpr.getArgs.asScala.mkString(", ")})" - - val callNode = NewCall() - .name(callee.getName) - .code(code) - .dispatchType(dispatchType) - .order(order) - .argumentIndex(order) - .methodFullName(s"$calleeType.${callee.getName}:$signature") - .signature(signature) - .typeFullName(callType) - .lineNumber(line(parentUnit)) - .columnNumber(column(parentUnit)) - - val argAsts = withOrder(invokeExpr match - case x: DynamicInvokeExpr => x.getArgs.asScala ++ x.getBootstrapArgs.asScala - case x => x.getArgs.asScala - ) { case (arg, order) => - astsForValue(arg, order, parentUnit) - }.flatten - - val callAst = Ast(callNode) - .withChildren(thisAsts) - .withChildren(argAsts) - .withArgEdges(callNode, thisAsts.flatMap(_.root)) - .withArgEdges(callNode, argAsts.flatMap(_.root)) - - thisAsts.flatMap(_.root).headOption match - case Some(thisAst) => callAst.withReceiverEdge(callNode, thisAst) - case None => callAst - end astForInvokeExpr - - private def astForNewExpr(x: AnyNewExpr, order: Int, parentUnit: soot.Unit): Ast = - x match - case u: NewArrayExpr => - astForArrayCreateExpr(x, List(u.getSize), order, parentUnit) - case u: NewMultiArrayExpr => - astForArrayCreateExpr(x, u.getSizes.asScala, order, parentUnit) - case _ => - val parentType = registerType(x.getType.toQuotedString) - Ast( - NewCall() - .name(Operators.alloc) - .methodFullName(Operators.alloc) - .typeFullName(parentType) - .code(s"new ${x.getType}") - .dispatchType(DispatchTypes.STATIC_DISPATCH) - .order(order) - .argumentIndex(order) - .lineNumber(line(parentUnit)) - .columnNumber(column(parentUnit)) - ) - - private def astForArrayCreateExpr( - arrayInitExpr: Expr, - sizes: Iterable[Value], - order: Int, - parentUnit: soot.Unit - ): Ast = - // Jimple does not have Operators.arrayInitializer - // to enforce 3 address code form - val arrayBaseType = registerType(arrayInitExpr.getType.toQuotedString) - val code = - s"new ${arrayBaseType.substring(0, arrayBaseType.indexOf('['))}${sizes.map(s => s"[$s]").mkString}" - val callBlock = NewCall() - .name(Operators.alloc) - .methodFullName(Operators.alloc) - .code(code) - .dispatchType(DispatchTypes.STATIC_DISPATCH) - .order(order) - .typeFullName(arrayBaseType) - .argumentIndex(order) - .lineNumber(line(parentUnit)) - .columnNumber(column(parentUnit)) - val valueAsts = withOrder(sizes) { (s, o) => - astsForValue(s, o, parentUnit) - }.flatten - Ast(callBlock) - .withChildren(valueAsts) - .withArgEdges(callBlock, valueAsts.flatMap(_.root)) - end astForArrayCreateExpr - - private def astForUnaryExpr( - methodName: String, - unaryExpr: Expr, - op: Value, - order: Int, - parentUnit: soot.Unit - ): Ast = - val callBlock = NewCall() - .name(methodName) - .methodFullName(methodName) - .code(unaryExpr.toString()) - .dispatchType(DispatchTypes.STATIC_DISPATCH) + ) + codeBuilder.append(s"${annoElement.getName} = ") + astChildren.append(annoElement match + case x: AnnotationAnnotationElem => + val rhsAst = astsForAnnotations(x.getValue, astChildren.size + 1, parent) + codeBuilder.append( + s"${rhsAst.root.flatMap(_.properties.get(PropertyNames.CODE)).mkString(", ")}" + ) + rhsAst + case x: AnnotationArrayElem => + val (rhsAst, code) = astForAnnotationArrayElement(x, astChildren.size + 1, parent) + codeBuilder.append(code) + rhsAst + case x => + val (name, code) = x match + case y: AnnotationStringElem => (y.getValue, s"\"${y.getValue}\"") + case _ => getLiteralElementNameAndCode(x) + val rhsOrder = if annoElement.getName == null then order else astChildren.size + 1 + codeBuilder.append(code) + Ast(NewAnnotationLiteral().name(name).code(code).order(rhsOrder).argumentIndex( + rhsOrder + )) + ) + + if astChildren.size == 1 then + astChildren.head + else + val paramAssign = NewAnnotationParameterAssign() + .code(codeBuilder.toString) + .lineNumber(lineNo) + .columnNumber(columnNo) + .order(order) + + Ast(paramAssign) + .withChildren(astChildren) + end astForAnnotationElement + + private def astForAnnotationArrayElement( + x: AnnotationArrayElem, + order: Int, + parent: AbstractHost + ): (Ast, String) = + val elems = withOrder(x.getValues.asScala) { (elem, order) => + astForAnnotationElement(elem, order, parent) + } + val code = + s"{${elems.flatMap(_.root).flatMap(_.properties.get(PropertyNames.CODE)).mkString(", ")}}" + val array = NewArrayInitializer().code(code).order(order).argumentIndex(order) + (Ast(array).withChildren(elems), code) + + private def astForMethodBody(body: Body, order: Int): Ast = + val block = NewBlock().order(order).lineNumber(line(body)).columnNumber(column(body)) + val jimpleParams = body.getParameterLocals.asScala.toList + // Don't let parameters also become locals (avoiding duplication) + val jimpleLocals = body.getLocals.asScala.filterNot(l => + jimpleParams.contains(l) || l.getName == "this" + ).toList + val locals = withOrder(jimpleLocals) { case (l, order) => + val name = l.getName + val typeFullName = registerType(l.getType.toQuotedString) + val code = s"$typeFullName $name" + Ast(NewLocal().name(name).code(code).typeFullName(typeFullName).order(order)) + } + Ast(block) + .withChildren(locals) + .withChildren(withOrder(body.getUnits.asScala.filterNot(isIgnoredUnit)) { (x, order) => + astsForStatement(x, order + locals.size) + }.flatten) + + private def isIgnoredUnit(unit: soot.Unit): Boolean = + unit match + case _: IdentityStmt => true + case _: NopStmt => true + case _ => false + + private def astsForStatement(statement: soot.Unit, order: Int): Seq[Ast] = + val stmt = statement match + case x: AssignStmt => astsForDefinition(x, order) + case x: InvokeStmt => astsForExpression(x.getInvokeExpr, order, statement) + case x: ReturnStmt => astsForReturnNode(x, order) + case x: ReturnVoidStmt => astsForReturnVoidNode(x, order) + case x: IfStmt => astsForIfStmt(x, order) + case x: GotoStmt => astsForGotoStmt(x, order) + case x: LookupSwitchStmt => astsForLookupSwitchStmt(x, order) + case x: TableSwitchStmt => astsForTableSwitchStmt(x, order) + case x: ThrowStmt => astsForThrowStmt(x, order) + case x: MonitorStmt => astsForMonitorStmt(x, order) + case _: IdentityStmt => Seq() // Identity statements redefine parameters as locals + case _: NopStmt => Seq() // Ignore NOP statements + case x => + logger.warn(s"Unhandled soot.Unit type ${x.getClass}") + Seq(astForUnknownStmt(x, None, order)) + unitToAsts.put(statement, stmt) + stmt + end astsForStatement + + private def astForBinOpExpr(binOp: BinopExpr, order: Int, parentUnit: soot.Unit): Ast = + // https://javadoc.io/static/org.soot-oss/soot/4.3.0/soot/jimple/BinopExpr.html + val operatorName = binOp match + case _: AddExpr => Operators.addition + case _: SubExpr => Operators.subtraction + case _: MulExpr => Operators.multiplication + case _: DivExpr => Operators.division + case _: RemExpr => Operators.modulo + case _: GeExpr => Operators.greaterEqualsThan + case _: GtExpr => Operators.greaterThan + case _: LeExpr => Operators.lessEqualsThan + case _: LtExpr => Operators.lessThan + case _: ShlExpr => Operators.shiftLeft + case _: ShrExpr => Operators.logicalShiftRight + case _: UshrExpr => Operators.arithmeticShiftRight + case _: CmpExpr => Operators.compare + case _: CmpgExpr => Operators.compare + case _: CmplExpr => Operators.compare + case _: AndExpr => Operators.and + case _: OrExpr => Operators.or + case _: XorExpr => Operators.xor + case _: EqExpr => Operators.equals + case _: NeExpr => Operators.notEquals + case _ => + logger.warn( + s"Unhandled binary operator ${binOp.getSymbol} (${binOp.getClass}). This is unexpected behaviour." + ) + ".unknown" + + val callNode = NewCall() + .name(operatorName) + .methodFullName(operatorName) + .dispatchType(DispatchTypes.STATIC_DISPATCH) + .code(binOp.toString) + .argumentIndex(order) + .order(order) + + val args = + astsForValue(binOp.getOp1, 1, parentUnit) ++ astsForValue(binOp.getOp2, 2, parentUnit) + callAst(callNode, args) + end astForBinOpExpr + + private def astsForExpression(expr: Expr, order: Int, parentUnit: soot.Unit): Seq[Ast] = + expr match + case x: BinopExpr => Seq(astForBinOpExpr(x, order, parentUnit)) + case x: InvokeExpr => Seq(astForInvokeExpr(x, order, parentUnit)) + case x: AnyNewExpr => Seq(astForNewExpr(x, order, parentUnit)) + case x: CastExpr => Seq(astForUnaryExpr(Operators.cast, x, x.getOp, order, parentUnit)) + case x: InstanceOfExpr => + Seq(astForUnaryExpr(Operators.instanceOf, x, x.getOp, order, parentUnit)) + case x: LengthExpr => + Seq(astForUnaryExpr(Operators.lengthOf, x, x.getOp, order, parentUnit)) + case x: NegExpr => Seq(astForUnaryExpr(Operators.minus, x, x.getOp, order, parentUnit)) + case x => + logger.warn(s"Unhandled soot.Expr type ${x.getClass}") + Seq() + + private def astsForValue(value: soot.Value, order: Int, parentUnit: soot.Unit): Seq[Ast] = + value match + case x: Expr => astsForExpression(x, order, parentUnit) + case x: soot.Local => Seq(astForLocal(x, order, parentUnit)) + case x: CaughtExceptionRef => Seq(astForCaughtExceptionRef(x, order, parentUnit)) + case x: Constant => Seq(astForConstantExpr(x, order)) + case x: FieldRef => Seq(astForFieldRef(x, order, parentUnit)) + case x: ThisRef => Seq(createThisNode(x)) + case x: ParameterRef => Seq(createParameterNode(x, order)) + case x: IdentityRef => Seq(astForIdentityRef(x, order, parentUnit)) + case x: ArrayRef => Seq(astForArrayRef(x, order, parentUnit)) + case x => + logger.warn(s"Unhandled soot.Value type ${x.getClass}") + Seq() + + private def astForArrayRef(arrRef: ArrayRef, order: Int, parentUnit: soot.Unit): Ast = + val indexAccess = NewCall() + .name(Operators.indexAccess) + .methodFullName(Operators.indexAccess) + .dispatchType(DispatchTypes.STATIC_DISPATCH) + .code(arrRef.toString()) + .order(order) + .argumentIndex(order) + .lineNumber(line(parentUnit)) + .columnNumber(column(parentUnit)) + .typeFullName(registerType(arrRef.getType.toQuotedString)) + + val astChildren = astsForValue(arrRef.getBase, 1, parentUnit) ++ astsForValue( + arrRef.getIndex, + 2, + parentUnit + ) + Ast(indexAccess) + .withChildren(astChildren) + .withArgEdges(indexAccess, astChildren.flatMap(_.root)) + end astForArrayRef + + private def astForLocal(local: soot.Local, order: Int, parentUnit: soot.Unit): Ast = + val name = local.getName + val typeFullName = registerType(local.getType.toQuotedString) + Ast( + NewIdentifier() + .name(name) + .lineNumber(line(parentUnit)) + .columnNumber(column(parentUnit)) + .order(order) + .argumentIndex(order) + .code(name) + .typeFullName(typeFullName) + ) + + private def astForIdentityRef(x: IdentityRef, order: Int, parentUnit: soot.Unit): Ast = + Ast( + NewIdentifier() + .code(x.toString()) + .name(x.toString()) .order(order) - .typeFullName(registerType(unaryExpr.getType.toQuotedString)) .argumentIndex(order) + .typeFullName(registerType(x.getType.toQuotedString)) .lineNumber(line(parentUnit)) .columnNumber(column(parentUnit)) - - def astForTypeRef(t: String, order: Int) = - Seq( - Ast( - NewTypeRef() - .code(if t.contains('.') then t.substring(t.lastIndexOf('.') + 1, t.length) - else t) - .order(order) - .argumentIndex(order) - .lineNumber(line(parentUnit)) - .columnNumber(column(parentUnit)) - .typeFullName(t) - ) - ) - - val valueAsts = unaryExpr match - case instanceOfExpr: InstanceOfExpr => - val t = registerType(instanceOfExpr.getCheckType.toQuotedString) - astsForValue(op, 1, parentUnit) ++ astForTypeRef(t, 2) - case castExpr: CastExpr => - val t = registerType(castExpr.getCastType.toQuotedString) - astForTypeRef(t, 1) ++ astsForValue(op, 2, parentUnit) - case _ => astsForValue(op, 1, parentUnit) - - Ast(callBlock) - .withChildren(valueAsts) - .withArgEdges(callBlock, valueAsts.flatMap(_.root)) - end astForUnaryExpr - - private def createThisNode(method: ThisRef): Ast = - Ast( - NewIdentifier() - .name("this") - .code("this") - .typeFullName(registerType(method.getType.toQuotedString)) - .dynamicTypeHintFullName(Seq(registerType(method.getType.toQuotedString))) - .order(0) - .argumentIndex(0) - ) - - private def createThisNode(method: SootMethod, builder: NewNode): Ast = - createThisNode(method.makeRef(), builder) - - private def createThisNode(method: SootMethodRef, builder: NewNode): Ast = - if !method.isStatic || method.isConstructor then - val parentType = - registerType(Try(method.getDeclaringClass.getType.toQuotedString).getOrElse("ANY")) - Ast(builder match - case x: NewIdentifier => - x.name("this") - .code("this") - .typeFullName(parentType) - .order(0) - .argumentIndex(0) - .dynamicTypeHintFullName(Seq(parentType)) - case _: NewMethodParameterIn => - NodeBuilders.newThisParameterNode( - typeFullName = parentType, - dynamicTypeHintFullName = Seq(parentType), - line = line(Try(method.tryResolve()).getOrElse(null)) - ) - case x => x - ) + ) + + private def astForInvokeExpr(invokeExpr: InvokeExpr, order: Int, parentUnit: soot.Unit): Ast = + val callee = invokeExpr.getMethodRef + val dispatchType = invokeExpr match + case _ if callee.isConstructor => DispatchTypes.STATIC_DISPATCH + case _: DynamicInvokeExpr => DispatchTypes.DYNAMIC_DISPATCH + case _: InstanceInvokeExpr => DispatchTypes.DYNAMIC_DISPATCH + case _ => DispatchTypes.STATIC_DISPATCH + + val signature = + s"${registerType(callee.getReturnType.toQuotedString)}(${(for (i <- 0 until callee.getParameterTypes.size()) + yield registerType(callee.getParameterType(i).toQuotedString)).mkString(",")})" + val thisAsts = invokeExpr match + case expr: InstanceInvokeExpr => astsForValue(expr.getBase, 0, parentUnit) + case _ => Seq(createThisNode(callee, NewIdentifier())) + + val methodName = + if callee.isConstructor then + registerType(callee.getDeclaringClass.getType.getClassName) else - Ast() - - private def createParameterNode(parameterRef: ParameterRef, order: Int): Ast = - val name = s"@parameter${parameterRef.getIndex}" - Ast( - NewIdentifier() - .name(name) - .code(name) - .typeFullName(registerType(parameterRef.getType.toQuotedString)) - .order(order) - .argumentIndex(order) - ) + callee.getName + + val calleeType = registerType(callee.getDeclaringClass.getType.toQuotedString) + val callType = + if callee.isConstructor then "void" + else calleeType + + val code = invokeExpr match + case expr: InstanceInvokeExpr => + s"${expr.getBase}.$methodName(${invokeExpr.getArgs.asScala.mkString(", ")})" + case _ => s"$methodName(${invokeExpr.getArgs.asScala.mkString(", ")})" + + val callNode = NewCall() + .name(callee.getName) + .code(code) + .dispatchType(dispatchType) + .order(order) + .argumentIndex(order) + .methodFullName(s"$calleeType.${callee.getName}:$signature") + .signature(signature) + .typeFullName(callType) + .lineNumber(line(parentUnit)) + .columnNumber(column(parentUnit)) + + val argAsts = withOrder(invokeExpr match + case x: DynamicInvokeExpr => x.getArgs.asScala ++ x.getBootstrapArgs.asScala + case x => x.getArgs.asScala + ) { case (arg, order) => + astsForValue(arg, order, parentUnit) + }.flatten + + val callAst = Ast(callNode) + .withChildren(thisAsts) + .withChildren(argAsts) + .withArgEdges(callNode, thisAsts.flatMap(_.root)) + .withArgEdges(callNode, argAsts.flatMap(_.root)) + + thisAsts.flatMap(_.root).headOption match + case Some(thisAst) => callAst.withReceiverEdge(callNode, thisAst) + case None => callAst + end astForInvokeExpr + + private def astForNewExpr(x: AnyNewExpr, order: Int, parentUnit: soot.Unit): Ast = + x match + case u: NewArrayExpr => + astForArrayCreateExpr(x, List(u.getSize), order, parentUnit) + case u: NewMultiArrayExpr => + astForArrayCreateExpr(x, u.getSizes.asScala, order, parentUnit) + case _ => + val parentType = registerType(x.getType.toQuotedString) + Ast( + NewCall() + .name(Operators.alloc) + .methodFullName(Operators.alloc) + .typeFullName(parentType) + .code(s"new ${x.getType}") + .dispatchType(DispatchTypes.STATIC_DISPATCH) + .order(order) + .argumentIndex(order) + .lineNumber(line(parentUnit)) + .columnNumber(column(parentUnit)) + ) - /** Creates the AST for assignment statements keeping in mind Jimple is a 3-address code - * language. - */ - private def astsForDefinition(assignStmt: DefinitionStmt, order: Int): Seq[Ast] = - val initializer = assignStmt.getRightOp - val leftOp = assignStmt.getLeftOp - - val identifier = leftOp match - case x: soot.Local => Seq(astForLocal(x, 1, assignStmt)) - case x: FieldRef => Seq(astForFieldRef(x, 1, assignStmt)) - case x => astsForValue(x, 1, assignStmt) - val lhsCode = - identifier.flatMap(_.root).flatMap(_.properties.get(PropertyNames.CODE)).mkString - - val initAsts = astsForValue(initializer, 2, assignStmt) - val rhsCode = initAsts - .flatMap(_.root) - .map(_.properties.getOrElse(PropertyNames.CODE, "")) - .mkString(", ") - - val assignment = NewCall() - .name(Operators.assignment) - .methodFullName(Operators.assignment) - .code(s"$lhsCode = $rhsCode") - .dispatchType(DispatchTypes.STATIC_DISPATCH) - .order(order) - .argumentIndex(order) - .typeFullName(registerType(assignStmt.getLeftOp.getType.toQuotedString)) - val initializerAst = Seq(callAst(assignment, identifier ++ initAsts)) - initializerAst.toList - end astsForDefinition - - private def astsForIfStmt(ifStmt: IfStmt, order: Int): Seq[Ast] = - // bytecode/jimple ASTs are flat so there will not be nested bodies - val condition = astsForValue(ifStmt.getCondition, order, ifStmt) - controlTargets.put(condition, ifStmt.getTarget) - condition - - private def astsForGotoStmt(gotoStmt: GotoStmt, order: Int): Seq[Ast] = - // bytecode/jimple ASTs are flat so there will not be nested bodies - val gotoAst = Seq( + private def astForArrayCreateExpr( + arrayInitExpr: Expr, + sizes: Iterable[Value], + order: Int, + parentUnit: soot.Unit + ): Ast = + // Jimple does not have Operators.arrayInitializer + // to enforce 3 address code form + val arrayBaseType = registerType(arrayInitExpr.getType.toQuotedString) + val code = + s"new ${arrayBaseType.substring(0, arrayBaseType.indexOf('['))}${sizes.map(s => s"[$s]").mkString}" + val callBlock = NewCall() + .name(Operators.alloc) + .methodFullName(Operators.alloc) + .code(code) + .dispatchType(DispatchTypes.STATIC_DISPATCH) + .order(order) + .typeFullName(arrayBaseType) + .argumentIndex(order) + .lineNumber(line(parentUnit)) + .columnNumber(column(parentUnit)) + val valueAsts = withOrder(sizes) { (s, o) => + astsForValue(s, o, parentUnit) + }.flatten + Ast(callBlock) + .withChildren(valueAsts) + .withArgEdges(callBlock, valueAsts.flatMap(_.root)) + end astForArrayCreateExpr + + private def astForUnaryExpr( + methodName: String, + unaryExpr: Expr, + op: Value, + order: Int, + parentUnit: soot.Unit + ): Ast = + val callBlock = NewCall() + .name(methodName) + .methodFullName(methodName) + .code(unaryExpr.toString()) + .dispatchType(DispatchTypes.STATIC_DISPATCH) + .order(order) + .typeFullName(registerType(unaryExpr.getType.toQuotedString)) + .argumentIndex(order) + .lineNumber(line(parentUnit)) + .columnNumber(column(parentUnit)) + + def astForTypeRef(t: String, order: Int) = + Seq( Ast( - NewUnknown() - .code(s"goto ${line(gotoStmt.getTarget).getOrElse(gotoStmt.getTarget.toString())}") + NewTypeRef() + .code(if t.contains('.') then t.substring(t.lastIndexOf('.') + 1, t.length) + else t) .order(order) .argumentIndex(order) - .lineNumber(line(gotoStmt)) - .columnNumber(column(gotoStmt)) + .lineNumber(line(parentUnit)) + .columnNumber(column(parentUnit)) + .typeFullName(t) ) ) - controlTargets.put(gotoAst, gotoStmt.getTarget) - gotoAst - - private def astForSwitchWithDefaultAndCondition(switchStmt: SwitchStmt, order: Int): Ast = - val jimple = switchStmt.toString() - val totalTgts = switchStmt.getTargets.size() - val switch = NewControlStructure() - .controlStructureType(ControlStructureTypes.SWITCH) - .code(jimple.substring(0, jimple.indexOf("{") - 1)) - .lineNumber(line(switchStmt)) - .columnNumber(column(switchStmt)) + + val valueAsts = unaryExpr match + case instanceOfExpr: InstanceOfExpr => + val t = registerType(instanceOfExpr.getCheckType.toQuotedString) + astsForValue(op, 1, parentUnit) ++ astForTypeRef(t, 2) + case castExpr: CastExpr => + val t = registerType(castExpr.getCastType.toQuotedString) + astForTypeRef(t, 1) ++ astsForValue(op, 2, parentUnit) + case _ => astsForValue(op, 1, parentUnit) + + Ast(callBlock) + .withChildren(valueAsts) + .withArgEdges(callBlock, valueAsts.flatMap(_.root)) + end astForUnaryExpr + + private def createThisNode(method: ThisRef): Ast = + Ast( + NewIdentifier() + .name("this") + .code("this") + .typeFullName(registerType(method.getType.toQuotedString)) + .dynamicTypeHintFullName(Seq(registerType(method.getType.toQuotedString))) + .order(0) + .argumentIndex(0) + ) + + private def createThisNode(method: SootMethod, builder: NewNode): Ast = + createThisNode(method.makeRef(), builder) + + private def createThisNode(method: SootMethodRef, builder: NewNode): Ast = + if !method.isStatic || method.isConstructor then + val parentType = + registerType(Try(method.getDeclaringClass.getType.toQuotedString).getOrElse("ANY")) + Ast(builder match + case x: NewIdentifier => + x.name("this") + .code("this") + .typeFullName(parentType) + .order(0) + .argumentIndex(0) + .dynamicTypeHintFullName(Seq(parentType)) + case _: NewMethodParameterIn => + NodeBuilders.newThisParameterNode( + typeFullName = parentType, + dynamicTypeHintFullName = Seq(parentType), + line = line(Try(method.tryResolve()).getOrElse(null)) + ) + case x => x + ) + else + Ast() + + private def createParameterNode(parameterRef: ParameterRef, order: Int): Ast = + val name = s"@parameter${parameterRef.getIndex}" + Ast( + NewIdentifier() + .name(name) + .code(name) + .typeFullName(registerType(parameterRef.getType.toQuotedString)) + .order(order) + .argumentIndex(order) + ) + + /** Creates the AST for assignment statements keeping in mind Jimple is a 3-address code language. + */ + private def astsForDefinition(assignStmt: DefinitionStmt, order: Int): Seq[Ast] = + val initializer = assignStmt.getRightOp + val leftOp = assignStmt.getLeftOp + + val identifier = leftOp match + case x: soot.Local => Seq(astForLocal(x, 1, assignStmt)) + case x: FieldRef => Seq(astForFieldRef(x, 1, assignStmt)) + case x => astsForValue(x, 1, assignStmt) + val lhsCode = + identifier.flatMap(_.root).flatMap(_.properties.get(PropertyNames.CODE)).mkString + + val initAsts = astsForValue(initializer, 2, assignStmt) + val rhsCode = initAsts + .flatMap(_.root) + .map(_.properties.getOrElse(PropertyNames.CODE, "")) + .mkString(", ") + + val assignment = NewCall() + .name(Operators.assignment) + .methodFullName(Operators.assignment) + .code(s"$lhsCode = $rhsCode") + .dispatchType(DispatchTypes.STATIC_DISPATCH) + .order(order) + .argumentIndex(order) + .typeFullName(registerType(assignStmt.getLeftOp.getType.toQuotedString)) + val initializerAst = Seq(callAst(assignment, identifier ++ initAsts)) + initializerAst.toList + end astsForDefinition + + private def astsForIfStmt(ifStmt: IfStmt, order: Int): Seq[Ast] = + // bytecode/jimple ASTs are flat so there will not be nested bodies + val condition = astsForValue(ifStmt.getCondition, order, ifStmt) + controlTargets.put(condition, ifStmt.getTarget) + condition + + private def astsForGotoStmt(gotoStmt: GotoStmt, order: Int): Seq[Ast] = + // bytecode/jimple ASTs are flat so there will not be nested bodies + val gotoAst = Seq( + Ast( + NewUnknown() + .code(s"goto ${line(gotoStmt.getTarget).getOrElse(gotoStmt.getTarget.toString())}") .order(order) .argumentIndex(order) - - val conditionalAst = astsForValue(switchStmt.getKey, totalTgts + 1, switchStmt) - val defaultAst = Seq( - Ast( - NewJumpTarget() - .name("default") - .code("default:") - .order(totalTgts + 2) - .argumentIndex(totalTgts + 2) - .lineNumber(line(switchStmt.getDefaultTarget)) - .columnNumber(column(switchStmt.getDefaultTarget)) - ) + .lineNumber(line(gotoStmt)) + .columnNumber(column(gotoStmt)) + ) + ) + controlTargets.put(gotoAst, gotoStmt.getTarget) + gotoAst + + private def astForSwitchWithDefaultAndCondition(switchStmt: SwitchStmt, order: Int): Ast = + val jimple = switchStmt.toString() + val totalTgts = switchStmt.getTargets.size() + val switch = NewControlStructure() + .controlStructureType(ControlStructureTypes.SWITCH) + .code(jimple.substring(0, jimple.indexOf("{") - 1)) + .lineNumber(line(switchStmt)) + .columnNumber(column(switchStmt)) + .order(order) + .argumentIndex(order) + + val conditionalAst = astsForValue(switchStmt.getKey, totalTgts + 1, switchStmt) + val defaultAst = Seq( + Ast( + NewJumpTarget() + .name("default") + .code("default:") + .order(totalTgts + 2) + .argumentIndex(totalTgts + 2) + .lineNumber(line(switchStmt.getDefaultTarget)) + .columnNumber(column(switchStmt.getDefaultTarget)) + ) + ) + Ast(switch) + .withConditionEdge(switch, conditionalAst.flatMap(_.root).head) + .withChildren(conditionalAst ++ defaultAst) + end astForSwitchWithDefaultAndCondition + + private def astsForLookupSwitchStmt(lookupSwitchStmt: LookupSwitchStmt, order: Int): Seq[Ast] = + val totalTgts = lookupSwitchStmt.getTargets.size() + val switchAst = astForSwitchWithDefaultAndCondition(lookupSwitchStmt, order) + + val tgts = + for + i <- 0 until totalTgts + if lookupSwitchStmt.getTarget(i) != lookupSwitchStmt.getDefaultTarget + yield (lookupSwitchStmt.getLookupValue(i), lookupSwitchStmt.getTarget(i)) + val tgtAsts = tgts.map { case (lookup, target) => + Ast( + NewJumpTarget() + .name(s"case $lookup") + .code(s"case $lookup:") + .argumentIndex(lookup) + .order(lookup) + .lineNumber(line(target)) + .columnNumber(column(target)) ) - Ast(switch) - .withConditionEdge(switch, conditionalAst.flatMap(_.root).head) - .withChildren(conditionalAst ++ defaultAst) - end astForSwitchWithDefaultAndCondition - - private def astsForLookupSwitchStmt(lookupSwitchStmt: LookupSwitchStmt, order: Int): Seq[Ast] = - val totalTgts = lookupSwitchStmt.getTargets.size() - val switchAst = astForSwitchWithDefaultAndCondition(lookupSwitchStmt, order) - - val tgts = - for - i <- 0 until totalTgts - if lookupSwitchStmt.getTarget(i) != lookupSwitchStmt.getDefaultTarget - yield (lookupSwitchStmt.getLookupValue(i), lookupSwitchStmt.getTarget(i)) - val tgtAsts = tgts.map { case (lookup, target) => + } + + Seq( + switchAst + .withChildren(tgtAsts) + ) + end astsForLookupSwitchStmt + + private def astsForTableSwitchStmt(tableSwitchStmt: SwitchStmt, order: Int): Seq[Ast] = + val switchAst = astForSwitchWithDefaultAndCondition(tableSwitchStmt, order) + val tgtAsts = tableSwitchStmt.getTargets.asScala + .filter(x => tableSwitchStmt.getDefaultTarget != x) + .zipWithIndex + .map({ case (tgt, i) => Ast( NewJumpTarget() - .name(s"case $lookup") - .code(s"case $lookup:") - .argumentIndex(lookup) - .order(lookup) - .lineNumber(line(target)) - .columnNumber(column(target)) + .name(s"case $i") + .code(s"case $i:") + .argumentIndex(i) + .order(i) + .lineNumber(line(tgt)) + .columnNumber(column(tgt)) ) - } - - Seq( - switchAst - .withChildren(tgtAsts) - ) - end astsForLookupSwitchStmt - - private def astsForTableSwitchStmt(tableSwitchStmt: SwitchStmt, order: Int): Seq[Ast] = - val switchAst = astForSwitchWithDefaultAndCondition(tableSwitchStmt, order) - val tgtAsts = tableSwitchStmt.getTargets.asScala - .filter(x => tableSwitchStmt.getDefaultTarget != x) - .zipWithIndex - .map({ case (tgt, i) => - Ast( - NewJumpTarget() - .name(s"case $i") - .code(s"case $i:") - .argumentIndex(i) - .order(i) - .lineNumber(line(tgt)) - .columnNumber(column(tgt)) - ) - }) - .toSeq - - Seq( - switchAst - .withChildren(tgtAsts) - ) - end astsForTableSwitchStmt - - private def astsForThrowStmt(throwStmt: ThrowStmt, order: Int): Seq[Ast] = - val opAst = astsForValue(throwStmt.getOp, 1, throwStmt) - val throwNode = NewCall() - .name(".throw") - .methodFullName(".throw") - .lineNumber(line(throwStmt)) - .columnNumber(column(throwStmt)) - .code(s"throw new ${throwStmt.getOp.getType}()") + }) + .toSeq + + Seq( + switchAst + .withChildren(tgtAsts) + ) + end astsForTableSwitchStmt + + private def astsForThrowStmt(throwStmt: ThrowStmt, order: Int): Seq[Ast] = + val opAst = astsForValue(throwStmt.getOp, 1, throwStmt) + val throwNode = NewCall() + .name(".throw") + .methodFullName(".throw") + .lineNumber(line(throwStmt)) + .columnNumber(column(throwStmt)) + .code(s"throw new ${throwStmt.getOp.getType}()") + .order(order) + .argumentIndex(order) + .dispatchType(DispatchTypes.STATIC_DISPATCH) + Seq( + Ast(throwNode) + .withChildren(opAst) + ) + + private def astsForMonitorStmt(monitorStmt: MonitorStmt, order: Int): Seq[Ast] = + val opAst = astsForValue(monitorStmt.getOp, 1, monitorStmt) + val typeString = opAst.flatMap(_.root).map(_.properties(PropertyNames.CODE)).mkString + val code = monitorStmt match + case _: EnterMonitorStmt => s"entermonitor $typeString" + case _: ExitMonitorStmt => s"exitmonitor $typeString" + case _ => s"monitor $typeString" + Seq( + Ast( + NewUnknown() .order(order) .argumentIndex(order) - .dispatchType(DispatchTypes.STATIC_DISPATCH) - Seq( - Ast(throwNode) - .withChildren(opAst) - ) - - private def astsForMonitorStmt(monitorStmt: MonitorStmt, order: Int): Seq[Ast] = - val opAst = astsForValue(monitorStmt.getOp, 1, monitorStmt) - val typeString = opAst.flatMap(_.root).map(_.properties(PropertyNames.CODE)).mkString - val code = monitorStmt match - case _: EnterMonitorStmt => s"entermonitor $typeString" - case _: ExitMonitorStmt => s"exitmonitor $typeString" - case _ => s"monitor $typeString" - Seq( - Ast( - NewUnknown() - .order(order) - .argumentIndex(order) - .code(code) - .lineNumber(line(monitorStmt)) - .columnNumber(column(monitorStmt)) - ).withChildren(opAst) + .code(code) + .lineNumber(line(monitorStmt)) + .columnNumber(column(monitorStmt)) + ).withChildren(opAst) + ) + + private def astForUnknownStmt(stmt: Unit, maybeOp: Option[Value], order: Int): Ast = + val opAst = maybeOp match + case Some(op) => astsForValue(op, 1, stmt) + case None => Seq() + val unknown = NewUnknown() + .order(order) + .code(stmt.toString()) + .lineNumber(line(stmt)) + .columnNumber(column(stmt)) + .typeFullName(registerType("void")) + Ast(unknown) + .withChildren(opAst) + + private def astsForReturnNode(returnStmt: ReturnStmt, order: Int): Seq[Ast] = + val astChildren = astsForValue(returnStmt.getOp, 1, returnStmt) + val returnNode = NewReturn() + .argumentIndex(order) + .order(order) + .code( + s"return ${astChildren.flatMap(_.root).map(_.properties(PropertyNames.CODE)).mkString(" ")};" ) + .lineNumber(line(returnStmt)) + .columnNumber(column(returnStmt)) - private def astForUnknownStmt(stmt: Unit, maybeOp: Option[Value], order: Int): Ast = - val opAst = maybeOp match - case Some(op) => astsForValue(op, 1, stmt) - case None => Seq() - val unknown = NewUnknown() - .order(order) - .code(stmt.toString()) - .lineNumber(line(stmt)) - .columnNumber(column(stmt)) - .typeFullName(registerType("void")) - Ast(unknown) - .withChildren(opAst) - - private def astsForReturnNode(returnStmt: ReturnStmt, order: Int): Seq[Ast] = - val astChildren = astsForValue(returnStmt.getOp, 1, returnStmt) - val returnNode = NewReturn() - .argumentIndex(order) - .order(order) - .code( - s"return ${astChildren.flatMap(_.root).map(_.properties(PropertyNames.CODE)).mkString(" ")};" - ) - .lineNumber(line(returnStmt)) - .columnNumber(column(returnStmt)) + Seq( + Ast(returnNode) + .withChildren(astChildren) + .withArgEdges(returnNode, astChildren.flatMap(_.root)) + ) - Seq( - Ast(returnNode) - .withChildren(astChildren) - .withArgEdges(returnNode, astChildren.flatMap(_.root)) - ) - - private def astsForReturnVoidNode(returnVoidStmt: ReturnVoidStmt, order: Int): Seq[Ast] = - Seq( - Ast( - NewReturn() - .argumentIndex(order) - .order(order) - .code(s"return;") - .lineNumber(line(returnVoidStmt)) - .columnNumber(column(returnVoidStmt)) - ) + private def astsForReturnVoidNode(returnVoidStmt: ReturnVoidStmt, order: Int): Seq[Ast] = + Seq( + Ast( + NewReturn() + .argumentIndex(order) + .order(order) + .code(s"return;") + .lineNumber(line(returnVoidStmt)) + .columnNumber(column(returnVoidStmt)) ) - - private def astForFieldRef(fieldRef: FieldRef, order: Int, parentUnit: soot.Unit): Ast = - val leftOpString = fieldRef match - case x: StaticFieldRef => x.getFieldRef.declaringClass().toString - case x: InstanceFieldRef => x.getBase.toString() - case _ => fieldRef.getFieldRef.declaringClass().toString - val leftOpType = fieldRef match - case x: StaticFieldRef => x.getFieldRef.declaringClass().getType - case x: InstanceFieldRef => x.getBase.getType - case _ => fieldRef.getFieldRef.declaringClass().getType - - val fieldAccessBlock = NewCall() - .name(Operators.fieldAccess) - .code(s"$leftOpString.${fieldRef.getFieldRef.name()}") - .typeFullName(registerType(fieldRef.getType.toQuotedString)) - .methodFullName(Operators.fieldAccess) - .dispatchType(DispatchTypes.STATIC_DISPATCH) + ) + + private def astForFieldRef(fieldRef: FieldRef, order: Int, parentUnit: soot.Unit): Ast = + val leftOpString = fieldRef match + case x: StaticFieldRef => x.getFieldRef.declaringClass().toString + case x: InstanceFieldRef => x.getBase.toString() + case _ => fieldRef.getFieldRef.declaringClass().toString + val leftOpType = fieldRef match + case x: StaticFieldRef => x.getFieldRef.declaringClass().getType + case x: InstanceFieldRef => x.getBase.getType + case _ => fieldRef.getFieldRef.declaringClass().getType + + val fieldAccessBlock = NewCall() + .name(Operators.fieldAccess) + .code(s"$leftOpString.${fieldRef.getFieldRef.name()}") + .typeFullName(registerType(fieldRef.getType.toQuotedString)) + .methodFullName(Operators.fieldAccess) + .dispatchType(DispatchTypes.STATIC_DISPATCH) + .order(order) + .argumentIndex(order) + .lineNumber(line(parentUnit)) + .columnNumber(column(parentUnit)) + + val argAsts = Seq( + NewIdentifier() + .order(1) + .argumentIndex(1) + .lineNumber(line(parentUnit)) + .columnNumber(column(parentUnit)) + .name(leftOpString) + .code(leftOpString) + .typeFullName(registerType(leftOpType.toQuotedString)), + NewFieldIdentifier() + .order(2) + .argumentIndex(2) + .lineNumber(line(parentUnit)) + .columnNumber(column(parentUnit)) + .canonicalName(fieldRef.getFieldRef.name()) + .code(fieldRef.getFieldRef.name()) + ).map(Ast(_)) + + Ast(fieldAccessBlock) + .withChildren(argAsts) + .withArgEdges(fieldAccessBlock, argAsts.flatMap(_.root)) + end astForFieldRef + + private def astForCaughtExceptionRef( + caughtException: CaughtExceptionRef, + order: Int, + parentUnit: soot.Unit + ): Ast = + Ast( + NewIdentifier() .order(order) .argumentIndex(order) .lineNumber(line(parentUnit)) .columnNumber(column(parentUnit)) + .name(caughtException.toString()) + .code(caughtException.toString()) + .typeFullName(registerType(caughtException.getType.toQuotedString)) + ) + + private def astForConstantExpr(constant: Constant, order: Int): Ast = + constant match + case x: ClassConstant => + Ast( + NewLiteral() + .order(order) + .argumentIndex(order) + .code(s"${x.value.parseAsJavaType}.class") + .typeFullName(registerType(x.getType.toQuotedString)) + ) + case _: NullConstant => + Ast( + NewLiteral() + .order(order) + .argumentIndex(order) + .code("null") + .typeFullName(registerType("null")) + ) + case _ => + Ast( + NewLiteral() + .order(order) + .argumentIndex(order) + .code(constant.toString) + .typeFullName(registerType(constant.getType.toQuotedString)) + ) - val argAsts = Seq( - NewIdentifier() - .order(1) - .argumentIndex(1) - .lineNumber(line(parentUnit)) - .columnNumber(column(parentUnit)) - .name(leftOpString) - .code(leftOpString) - .typeFullName(registerType(leftOpType.toQuotedString)), - NewFieldIdentifier() - .order(2) - .argumentIndex(2) - .lineNumber(line(parentUnit)) - .columnNumber(column(parentUnit)) - .canonicalName(fieldRef.getFieldRef.name()) - .code(fieldRef.getFieldRef.name()) - ).map(Ast(_)) - - Ast(fieldAccessBlock) - .withChildren(argAsts) - .withArgEdges(fieldAccessBlock, argAsts.flatMap(_.root)) - end astForFieldRef - - private def astForCaughtExceptionRef( - caughtException: CaughtExceptionRef, - order: Int, - parentUnit: soot.Unit - ): Ast = - Ast( - NewIdentifier() - .order(order) - .argumentIndex(order) - .lineNumber(line(parentUnit)) - .columnNumber(column(parentUnit)) - .name(caughtException.toString()) - .code(caughtException.toString()) - .typeFullName(registerType(caughtException.getType.toQuotedString)) - ) - - private def astForConstantExpr(constant: Constant, order: Int): Ast = - constant match - case x: ClassConstant => - Ast( - NewLiteral() - .order(order) - .argumentIndex(order) - .code(s"${x.value.parseAsJavaType}.class") - .typeFullName(registerType(x.getType.toQuotedString)) - ) - case _: NullConstant => - Ast( - NewLiteral() - .order(order) - .argumentIndex(order) - .code("null") - .typeFullName(registerType("null")) - ) - case _ => - Ast( - NewLiteral() - .order(order) - .argumentIndex(order) - .code(constant.toString) - .typeFullName(registerType(constant.getType.toQuotedString)) - ) - - private def callAst(rootNode: NewNode, args: Seq[Ast]): Ast = - Ast(rootNode) - .withChildren(args) - .withArgEdges(rootNode, args.flatMap(_.root)) - - private def astsForModifiers(methodDeclaration: SootMethod): Seq[Ast] = - Seq( - if methodDeclaration.isStatic then Some(ModifierTypes.STATIC) else None, - if methodDeclaration.isPublic then Some(ModifierTypes.PUBLIC) else None, - if methodDeclaration.isProtected then Some(ModifierTypes.PROTECTED) else None, - if methodDeclaration.isPrivate then Some(ModifierTypes.PRIVATE) else None, - if methodDeclaration.isAbstract then Some(ModifierTypes.ABSTRACT) else None, - if methodDeclaration.isConstructor then Some(ModifierTypes.CONSTRUCTOR) else None, - if !methodDeclaration.isFinal && !methodDeclaration.isStatic && methodDeclaration.isPublic - then - Some(ModifierTypes.VIRTUAL) - else None, - if methodDeclaration.isSynchronized then Some("SYNCHRONIZED") else None - ).flatten.map { modifier => - Ast(NewModifier().modifierType(modifier).code(modifier.toLowerCase)) - } - - private def astsForModifiers(classDeclaration: SootClass): Seq[Ast] = - Seq( - if classDeclaration.isStatic then Some(ModifierTypes.STATIC) else None, - if classDeclaration.isPublic then Some(ModifierTypes.PUBLIC) else None, - if classDeclaration.isProtected then Some(ModifierTypes.PROTECTED) else None, - if classDeclaration.isPrivate then Some(ModifierTypes.PRIVATE) else None, - if classDeclaration.isAbstract then Some(ModifierTypes.ABSTRACT) else None, - if classDeclaration.isInterface then Some("INTERFACE") else None, - if !classDeclaration.isFinal && !classDeclaration.isStatic && classDeclaration.isPublic - then - Some(ModifierTypes.VIRTUAL) - else None, - if classDeclaration.isSynchronized then Some("SYNCHRONIZED") else None - ).flatten.map { modifier => - Ast(NewModifier().modifierType(modifier).code(modifier.toLowerCase)) - } - - private def astForMethodReturn(methodDeclaration: SootMethod): Ast = - val typeFullName = registerType(methodDeclaration.getReturnType.toQuotedString) - val methodReturnNode = NodeBuilders - .newMethodReturnNode(typeFullName, None, line(methodDeclaration), None) - .order(methodDeclaration.getParameterCount + 2) - Ast(methodReturnNode) - - private def createMethodNode(methodDeclaration: SootMethod, typeDecl: RefType, childNum: Int) = - val fullName = methodFullName(typeDecl, methodDeclaration) - val methodDeclType = registerType(methodDeclaration.getReturnType.toQuotedString) - val code = if !methodDeclaration.isConstructor then - s"$methodDeclType ${methodDeclaration.getName}${paramListSignature(methodDeclaration, withParams = true)}" - else - s"${typeDecl.getClassName}${paramListSignature(methodDeclaration, withParams = true)}" - NewMethod() - .name(methodDeclaration.getName) - .fullName(fullName) - .code(code) - .signature(methodDeclType + paramListSignature(methodDeclaration)) - .isExternal(false) - .order(childNum) - .filename(filename) - .astParentType(NodeTypes.TYPE_DECL) - .astParentFullName(typeDecl.toQuotedString) - .lineNumber(line(methodDeclaration)) - .columnNumber(column(methodDeclaration)) - end createMethodNode - - private def methodFullName(typeDecl: RefType, methodDeclaration: SootMethod): String = - val typeName = registerType(typeDecl.toQuotedString) - val returnType = registerType(methodDeclaration.getReturnType.toQuotedString) - val methodName = methodDeclaration.getName - s"$typeName.$methodName:$returnType${paramListSignature(methodDeclaration)}" - - private def paramListSignature(methodDeclaration: SootMethod, withParams: Boolean = false) = - val paramTypes = - methodDeclaration.getParameterTypes.asScala.map(x => registerType(x.toQuotedString)) - - val paramNames = - if !methodDeclaration.isPhantom && Try(methodDeclaration.retrieveActiveBody()).isSuccess - then - methodDeclaration.retrieveActiveBody().getParameterLocals.asScala.map(_.getName) - else - paramTypes.zipWithIndex.map(x => s"param${x._2 + 1}") - if !withParams then - "(" + paramTypes.mkString(",") + ")" + private def callAst(rootNode: NewNode, args: Seq[Ast]): Ast = + Ast(rootNode) + .withChildren(args) + .withArgEdges(rootNode, args.flatMap(_.root)) + + private def astsForModifiers(methodDeclaration: SootMethod): Seq[Ast] = + Seq( + if methodDeclaration.isStatic then Some(ModifierTypes.STATIC) else None, + if methodDeclaration.isPublic then Some(ModifierTypes.PUBLIC) else None, + if methodDeclaration.isProtected then Some(ModifierTypes.PROTECTED) else None, + if methodDeclaration.isPrivate then Some(ModifierTypes.PRIVATE) else None, + if methodDeclaration.isAbstract then Some(ModifierTypes.ABSTRACT) else None, + if methodDeclaration.isConstructor then Some(ModifierTypes.CONSTRUCTOR) else None, + if !methodDeclaration.isFinal && !methodDeclaration.isStatic && methodDeclaration.isPublic + then + Some(ModifierTypes.VIRTUAL) + else None, + if methodDeclaration.isSynchronized then Some("SYNCHRONIZED") else None + ).flatten.map { modifier => + Ast(NewModifier().modifierType(modifier).code(modifier.toLowerCase)) + } + + private def astsForModifiers(classDeclaration: SootClass): Seq[Ast] = + Seq( + if classDeclaration.isStatic then Some(ModifierTypes.STATIC) else None, + if classDeclaration.isPublic then Some(ModifierTypes.PUBLIC) else None, + if classDeclaration.isProtected then Some(ModifierTypes.PROTECTED) else None, + if classDeclaration.isPrivate then Some(ModifierTypes.PRIVATE) else None, + if classDeclaration.isAbstract then Some(ModifierTypes.ABSTRACT) else None, + if classDeclaration.isInterface then Some("INTERFACE") else None, + if !classDeclaration.isFinal && !classDeclaration.isStatic && classDeclaration.isPublic + then + Some(ModifierTypes.VIRTUAL) + else None, + if classDeclaration.isSynchronized then Some("SYNCHRONIZED") else None + ).flatten.map { modifier => + Ast(NewModifier().modifierType(modifier).code(modifier.toLowerCase)) + } + + private def astForMethodReturn(methodDeclaration: SootMethod): Ast = + val typeFullName = registerType(methodDeclaration.getReturnType.toQuotedString) + val methodReturnNode = NodeBuilders + .newMethodReturnNode(typeFullName, None, line(methodDeclaration), None) + .order(methodDeclaration.getParameterCount + 2) + Ast(methodReturnNode) + + private def createMethodNode(methodDeclaration: SootMethod, typeDecl: RefType, childNum: Int) = + val fullName = methodFullName(typeDecl, methodDeclaration) + val methodDeclType = registerType(methodDeclaration.getReturnType.toQuotedString) + val code = if !methodDeclaration.isConstructor then + s"$methodDeclType ${methodDeclaration.getName}${paramListSignature(methodDeclaration, withParams = true)}" + else + s"${typeDecl.getClassName}${paramListSignature(methodDeclaration, withParams = true)}" + NewMethod() + .name(methodDeclaration.getName) + .fullName(fullName) + .code(code) + .signature(methodDeclType + paramListSignature(methodDeclaration)) + .isExternal(false) + .order(childNum) + .filename(filename) + .astParentType(NodeTypes.TYPE_DECL) + .astParentFullName(typeDecl.toQuotedString) + .lineNumber(line(methodDeclaration)) + .columnNumber(column(methodDeclaration)) + end createMethodNode + + private def methodFullName(typeDecl: RefType, methodDeclaration: SootMethod): String = + val typeName = registerType(typeDecl.toQuotedString) + val returnType = registerType(methodDeclaration.getReturnType.toQuotedString) + val methodName = methodDeclaration.getName + s"$typeName.$methodName:$returnType${paramListSignature(methodDeclaration)}" + + private def paramListSignature(methodDeclaration: SootMethod, withParams: Boolean = false) = + val paramTypes = + methodDeclaration.getParameterTypes.asScala.map(x => registerType(x.toQuotedString)) + + val paramNames = + if !methodDeclaration.isPhantom && Try(methodDeclaration.retrieveActiveBody()).isSuccess + then + methodDeclaration.retrieveActiveBody().getParameterLocals.asScala.map(_.getName) else - "(" + paramTypes.zip(paramNames).map(x => s"${x._1} ${x._2}").mkString(", ") + ")" + paramTypes.zipWithIndex.map(x => s"param${x._2 + 1}") + if !withParams then + "(" + paramTypes.mkString(",") + ")" + else + "(" + paramTypes.zip(paramNames).map(x => s"${x._1} ${x._2}").mkString(", ") + ")" end AstCreator object AstCreator: - def line(node: Host): Option[Integer] = - if node == null then None - else if node.getJavaSourceStartLineNumber == -1 then None - else Option(node.getJavaSourceStartLineNumber) - - def column(node: Host): Option[Integer] = - if node == null then None - else if node.getJavaSourceStartColumnNumber == -1 then None - else Option(node.getJavaSourceStartColumnNumber) - - def withOrder[T <: Any, X](nodeList: java.util.List[T])(f: (T, Int) => X): Seq[X] = - nodeList.asScala.zipWithIndex.map { case (x, i) => - f(x, i + 1) - }.toSeq - - def withOrder[T <: Any, X](nodeList: Iterable[T])(f: (T, Int) => X): Seq[X] = - nodeList.zipWithIndex.map { case (x, i) => - f(x, i + 1) - }.toSeq + def line(node: Host): Option[Integer] = + if node == null then None + else if node.getJavaSourceStartLineNumber == -1 then None + else Option(node.getJavaSourceStartLineNumber) + + def column(node: Host): Option[Integer] = + if node == null then None + else if node.getJavaSourceStartColumnNumber == -1 then None + else Option(node.getJavaSourceStartColumnNumber) + + def withOrder[T <: Any, X](nodeList: java.util.List[T])(f: (T, Int) => X): Seq[X] = + nodeList.asScala.zipWithIndex.map { case (x, i) => + f(x, i + 1) + }.toSeq + + def withOrder[T <: Any, X](nodeList: Iterable[T])(f: (T, Int) => X): Seq[X] = + nodeList.zipWithIndex.map { case (x, i) => + f(x, i + 1) + }.toSeq end AstCreator /** String extensions for strings describing JVM operators. */ implicit class JvmStringOpts(s: String): - /** Parses the string as a ASM Java type descriptor and returns a fully qualified type. Also - * converts symbols such as I to int. - * @return - */ - def parseAsJavaType: String = Type.getType(s).getClassName.replaceAll("/", ".") + /** Parses the string as a ASM Java type descriptor and returns a fully qualified type. Also + * converts symbols such as I to int. + * @return + */ + def parseAsJavaType: String = Type.getType(s).getClassName.replaceAll("/", ".") diff --git a/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/ConfigFileCreationPass.scala b/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/ConfigFileCreationPass.scala index d2dccc14..f977b783 100644 --- a/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/ConfigFileCreationPass.scala +++ b/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/ConfigFileCreationPass.scala @@ -6,32 +6,32 @@ import io.shiftleft.codepropertygraph.Cpg class ConfigFileCreationPass(cpg: Cpg) extends XConfigFileCreationPass(cpg): - override val configFileFilters: List[File => Boolean] = List( - extensionFilter(".properties"), - // Velocity files, see https://velocity.apache.org - extensionFilter(".vm"), - // For Terraform secrets - extensionFilter(".tf"), - extensionFilter(".tfvars"), - // PLAY - pathEndFilter("routes"), - pathEndFilter("application.conf"), - // SERVLET - pathEndFilter("web.xml"), - // JSF - pathEndFilter("faces-config.xml"), - // STRUTS - pathEndFilter("struts.xml"), - // DIRECT WEB REMOTING - pathEndFilter("dwr.xml"), - // BUILD SYSTEM - pathEndFilter("build.gradle"), - pathEndFilter("build.gradle.kts"), - // ANDROID - pathEndFilter("AndroidManifest.xml"), - // Bom - pathEndFilter("bom.json"), - pathEndFilter(".cdx.json"), - pathEndFilter("chennai.json") - ) + override val configFileFilters: List[File => Boolean] = List( + extensionFilter(".properties"), + // Velocity files, see https://velocity.apache.org + extensionFilter(".vm"), + // For Terraform secrets + extensionFilter(".tf"), + extensionFilter(".tfvars"), + // PLAY + pathEndFilter("routes"), + pathEndFilter("application.conf"), + // SERVLET + pathEndFilter("web.xml"), + // JSF + pathEndFilter("faces-config.xml"), + // STRUTS + pathEndFilter("struts.xml"), + // DIRECT WEB REMOTING + pathEndFilter("dwr.xml"), + // BUILD SYSTEM + pathEndFilter("build.gradle"), + pathEndFilter("build.gradle.kts"), + // ANDROID + pathEndFilter("AndroidManifest.xml"), + // Bom + pathEndFilter("bom.json"), + pathEndFilter(".cdx.json"), + pathEndFilter("chennai.json") + ) end ConfigFileCreationPass diff --git a/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/DeclarationRefPass.scala b/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/DeclarationRefPass.scala index 1cf94125..68baa82f 100644 --- a/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/DeclarationRefPass.scala +++ b/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/DeclarationRefPass.scala @@ -11,12 +11,12 @@ import io.shiftleft.semanticcpg.language.* */ class DeclarationRefPass(atom: Cpg) extends ConcurrentWriterCpgPass[Method](atom): - override def generateParts(): Array[Method] = atom.method.toArray + override def generateParts(): Array[Method] = atom.method.toArray - override def runOnPart(builder: DiffGraphBuilder, part: Method): Unit = - val identifiers = part.ast.isIdentifier.toList - val declarations = - (part.parameter ++ part.block.astChildren.isLocal).collectAll[Declaration].l - declarations.foreach(d => - identifiers.nameExact(d.name).foreach(builder.addEdge(_, d, EdgeTypes.REF)) - ) + override def runOnPart(builder: DiffGraphBuilder, part: Method): Unit = + val identifiers = part.ast.isIdentifier.toList + val declarations = + (part.parameter ++ part.block.astChildren.isLocal).collectAll[Declaration].l + declarations.foreach(d => + identifiers.nameExact(d.name).foreach(builder.addEdge(_, d, EdgeTypes.REF)) + ) diff --git a/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/SootAstCreationPass.scala b/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/SootAstCreationPass.scala index e89c971d..3714652f 100644 --- a/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/SootAstCreationPass.scala +++ b/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/passes/SootAstCreationPass.scala @@ -12,18 +12,18 @@ import soot.{Scene, SootClass, SourceLocator} */ class SootAstCreationPass(cpg: Cpg, config: Config) extends ConcurrentWriterCpgPass[SootClass](cpg): - val global: Global = new Global() - private val logger = LoggerFactory.getLogger(classOf[AstCreationPass]) + val global: Global = new Global() + private val logger = LoggerFactory.getLogger(classOf[AstCreationPass]) - override def generateParts(): Array[? <: AnyRef] = Scene.v().getApplicationClasses.toArray() + override def generateParts(): Array[? <: AnyRef] = Scene.v().getApplicationClasses.toArray() - override def runOnPart(builder: DiffGraphBuilder, part: SootClass): Unit = - val jimpleFile = SourceLocator.v().getSourceForClass(part.getName) - try - // sootClass.setApplicationClass() - val localDiff = - new AstCreator(jimpleFile, part, global)(config.schemaValidation).createAst() - builder.absorb(localDiff) - catch - case e: Exception => - logger.warn(s"Cannot parse: $part", e) + override def runOnPart(builder: DiffGraphBuilder, part: SootClass): Unit = + val jimpleFile = SourceLocator.v().getSourceForClass(part.getName) + try + // sootClass.setApplicationClass() + val localDiff = + new AstCreator(jimpleFile, part, global)(config.schemaValidation).createAst() + builder.absorb(localDiff) + catch + case e: Exception => + logger.warn(s"Cannot parse: $part", e) diff --git a/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/util/ProgramHandlingUtil.scala b/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/util/ProgramHandlingUtil.scala index 0553887c..a128ab7a 100644 --- a/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/util/ProgramHandlingUtil.scala +++ b/platform/frontends/jimple2cpg/src/main/scala/io/appthreat/jimple2cpg/util/ProgramHandlingUtil.scala @@ -13,191 +13,191 @@ import scala.util.{Failure, Left, Success, Try} */ object ProgramHandlingUtil: - private val logger = LoggerFactory.getLogger(ProgramHandlingUtil.getClass) - - /** Common properties of a File and ZipEntry, used to determine whether a file in a directory or - * an entry in an archive is worth emitting/extracting - */ - sealed class Entry(entry: Either[File, ZipEntry]): - - def this(file: File) = this(Left(file)) - def this(entry: ZipEntry) = this(Right(entry)) - private def file: File = entry.fold(identity, e => File(e.getName)) - def name: String = file.name - def extension: Option[String] = file.extension - def isDirectory: Boolean = entry.fold(_.isDirectory, _.isDirectory) - def maybeRegularFile(): Boolean = entry.fold(_.isRegularFile, !_.isDirectory) - - /** Determines whether a zip entry is potentially malicious. - * @return - * whether the entry is a ZipEntry and uses '..' in it's components - */ - // Note that we consider either type of path separator as although the spec say that only - // unix separators are to be used, zip files in the wild may vary. - def isZipSlip: Boolean = entry.fold(_ => false, _.getName.split("[/\\\\]").contains("..")) - - /** Process files that may lead to more files to process or to emit a resulting value of [[A]] - * - * @param src - * The file/directory to traverse - * @param emitOrUnpack - * A function that takes a file and either emits a value or returns more files to traverse - * @tparam A - * The type of emitted values + private val logger = LoggerFactory.getLogger(ProgramHandlingUtil.getClass) + + /** Common properties of a File and ZipEntry, used to determine whether a file in a directory or + * an entry in an archive is worth emitting/extracting + */ + sealed class Entry(entry: Either[File, ZipEntry]): + + def this(file: File) = this(Left(file)) + def this(entry: ZipEntry) = this(Right(entry)) + private def file: File = entry.fold(identity, e => File(e.getName)) + def name: String = file.name + def extension: Option[String] = file.extension + def isDirectory: Boolean = entry.fold(_.isDirectory, _.isDirectory) + def maybeRegularFile(): Boolean = entry.fold(_.isRegularFile, !_.isDirectory) + + /** Determines whether a zip entry is potentially malicious. * @return - * The emitted values + * whether the entry is a ZipEntry and uses '..' in it's components */ - private def unfoldArchives[A]( - src: File, - emitOrUnpack: File => Either[A, List[File]] - ): IterableOnce[A] = - // TODO: add recursion depth limit - emitOrUnpack(src) match - case Left(a) => Seq(a) - case Right(disposeFiles) => disposeFiles.flatMap(x => unfoldArchives(x, emitOrUnpack)) - - /** Find
.class
files, including those inside archives. + // Note that we consider either type of path separator as although the spec say that only + // unix separators are to be used, zip files in the wild may vary. + def isZipSlip: Boolean = entry.fold(_ => false, _.getName.split("[/\\\\]").contains("..")) + + /** Process files that may lead to more files to process or to emit a resulting value of [[A]] + * + * @param src + * The file/directory to traverse + * @param emitOrUnpack + * A function that takes a file and either emits a value or returns more files to traverse + * @tparam A + * The type of emitted values + * @return + * The emitted values + */ + private def unfoldArchives[A]( + src: File, + emitOrUnpack: File => Either[A, List[File]] + ): IterableOnce[A] = + // TODO: add recursion depth limit + emitOrUnpack(src) match + case Left(a) => Seq(a) + case Right(disposeFiles) => disposeFiles.flatMap(x => unfoldArchives(x, emitOrUnpack)) + + /** Find
.class
files, including those inside archives. + * + * @param src + * The file/directory to search. + * @param tmpDir + * A temporary directory for extracted archives + * @param isArchive + * Whether an entry is an archive to extract + * @param isClass + * Whether an entry is a class file + * @return + * The list of class files found, which may either be in [[src]] or in an extracted archive + * under [[tmpDir]] + */ + private def extractClassesToTmp( + src: File, + tmpDir: File, + isArchive: Entry => Boolean, + isClass: Entry => Boolean, + recurse: Boolean + ): IterableOnce[ClassFile] = + + def shouldExtract(e: Entry) = + !e.isZipSlip && e.maybeRegularFile() && (isArchive(e) || isClass(e)) + unfoldArchives( + src, + { + case f if isClass(Entry(f)) => + Left(ClassFile(f)) + case f if f.isDirectory() => + val files = f.listRecursively.filterNot(_.isDirectory).toList + Right(files) + case f if isArchive(Entry(f)) && (recurse || f == src) => + val xTmp = File.newTemporaryDirectory("extract-archive-", parent = Some(tmpDir)) + val unzipDirs = Try(f.unzipTo(xTmp, e => shouldExtract(Entry(e)))) match + case Success(dir) => List(dir) + case Failure(e) => + logger.warn(s"Failed to extract archive", e) + List.empty + Right(unzipDirs) + case _ => + Right(List.empty) + } + ) + end extractClassesToTmp + + object ClassFile: + private def getPackagePathFromByteCode(is: InputStream): Option[String] = + val cr = new ClassReader(is) + sealed class ClassNameVisitor extends ClassVisitor(Opcodes.ASM9): + var path: Option[String] = None + override def visit( + version: Int, + access: Int, + name: String, + signature: String, + superName: String, + interfaces: Array[String] + ): Unit = + path = Some(name) + val rootVisitor = new ClassNameVisitor() + cr.accept(rootVisitor, SKIP_CODE) + rootVisitor.path + + /** Attempt to retrieve the package path from JVM bytecode. * - * @param src - * The file/directory to search. - * @param tmpDir - * A temporary directory for extracted archives - * @param isArchive - * Whether an entry is an archive to extract - * @param isClass - * Whether an entry is a class file + * @param file + * The class file * @return - * The list of class files found, which may either be in [[src]] or in an extracted archive - * under [[tmpDir]] + * The package path if successfully retrieved */ - private def extractClassesToTmp( - src: File, - tmpDir: File, - isArchive: Entry => Boolean, - isClass: Entry => Boolean, - recurse: Boolean - ): IterableOnce[ClassFile] = - - def shouldExtract(e: Entry) = - !e.isZipSlip && e.maybeRegularFile() && (isArchive(e) || isClass(e)) - unfoldArchives( - src, - { - case f if isClass(Entry(f)) => - Left(ClassFile(f)) - case f if f.isDirectory() => - val files = f.listRecursively.filterNot(_.isDirectory).toList - Right(files) - case f if isArchive(Entry(f)) && (recurse || f == src) => - val xTmp = File.newTemporaryDirectory("extract-archive-", parent = Some(tmpDir)) - val unzipDirs = Try(f.unzipTo(xTmp, e => shouldExtract(Entry(e)))) match - case Success(dir) => List(dir) - case Failure(e) => - logger.warn(s"Failed to extract archive", e) - List.empty - Right(unzipDirs) - case _ => - Right(List.empty) - } - ) - end extractClassesToTmp - - object ClassFile: - private def getPackagePathFromByteCode(is: InputStream): Option[String] = - val cr = new ClassReader(is) - sealed class ClassNameVisitor extends ClassVisitor(Opcodes.ASM9): - var path: Option[String] = None - override def visit( - version: Int, - access: Int, - name: String, - signature: String, - superName: String, - interfaces: Array[String] - ): Unit = - path = Some(name) - val rootVisitor = new ClassNameVisitor() - cr.accept(rootVisitor, SKIP_CODE) - rootVisitor.path - - /** Attempt to retrieve the package path from JVM bytecode. - * - * @param file - * The class file - * @return - * The package path if successfully retrieved - */ - private def getPackagePathFromByteCode(file: File): Option[String] = - Try(file.fileInputStream.apply(getPackagePathFromByteCode)) - .recover { case e: Throwable => - logger.debug(s"Error reading class file ${file.canonicalPath}", e) - None - } - .getOrElse(None) - end ClassFile - sealed class ClassFile(val file: File, val packagePath: Option[String]): - def this(file: File) = this(file, ClassFile.getPackagePathFromByteCode(file)) - - private val components: Option[Array[String]] = packagePath.map(_.split("/")) - - val fullyQualifiedClassName: Option[String] = components.map(_.mkString(".")) - - /** Copy the class file to its package path relative to [[destDir]]. This will overwrite a - * class file at the destination if it exists. - * @param destDir - * The directory in which to place the class file - * @return - * The class file at the destination if the package path could be retrieved from the its - * bytecode - */ - def copyToPackageLayoutIn(destDir: File): Option[ClassFile] = - packagePath - .map { path => - val destClass = destDir / s"$path.class" - if destClass.exists() then - logger.warn(s"Overwriting class file: ${destClass.path.toAbsolutePath}") - destClass.parent.createDirectories() - ClassFile( - file.copyTo(destClass)(File.CopyOptions(overwrite = true)), - packagePath - ) - } - .orElse { - logger.warn( - s"Missing package path for ${file.canonicalPath}. Failed to copy to ${destDir.canonicalPath}" - ) - None - } - end ClassFile - - /** Find
.class
files, including those inside archives and copy them to their package - * path location relative to [[destDir]] - * - * @param src - * The file/directory to search. + private def getPackagePathFromByteCode(file: File): Option[String] = + Try(file.fileInputStream.apply(getPackagePathFromByteCode)) + .recover { case e: Throwable => + logger.debug(s"Error reading class file ${file.canonicalPath}", e) + None + } + .getOrElse(None) + end ClassFile + sealed class ClassFile(val file: File, val packagePath: Option[String]): + def this(file: File) = this(file, ClassFile.getPackagePathFromByteCode(file)) + + private val components: Option[Array[String]] = packagePath.map(_.split("/")) + + val fullyQualifiedClassName: Option[String] = components.map(_.mkString(".")) + + /** Copy the class file to its package path relative to [[destDir]]. This will overwrite a class + * file at the destination if it exists. * @param destDir - * The directory in which to place the class files - * @param isArchive - * Whether an entry is an archive to extract - * @param isClass - * Whether an entry is a class file - * @param recurse - * Whether to unpack recursively + * The directory in which to place the class file * @return - * The copied class files in destDir + * The class file at the destination if the package path could be retrieved from the its + * bytecode */ - def extractClassesInPackageLayout( - src: File, - destDir: File, - isClass: Entry => Boolean, - isArchive: Entry => Boolean, - recurse: Boolean - ): List[ClassFile] = - File - .temporaryDirectory("extract-classes-") - .apply(tmpDir => - extractClassesToTmp(src, tmpDir, isArchive, isClass, recurse: Boolean).iterator - .flatMap(_.copyToPackageLayoutIn(destDir)) - .toList - ) + def copyToPackageLayoutIn(destDir: File): Option[ClassFile] = + packagePath + .map { path => + val destClass = destDir / s"$path.class" + if destClass.exists() then + logger.warn(s"Overwriting class file: ${destClass.path.toAbsolutePath}") + destClass.parent.createDirectories() + ClassFile( + file.copyTo(destClass)(File.CopyOptions(overwrite = true)), + packagePath + ) + } + .orElse { + logger.warn( + s"Missing package path for ${file.canonicalPath}. Failed to copy to ${destDir.canonicalPath}" + ) + None + } + end ClassFile + + /** Find
.class
files, including those inside archives and copy them to their package + * path location relative to [[destDir]] + * + * @param src + * The file/directory to search. + * @param destDir + * The directory in which to place the class files + * @param isArchive + * Whether an entry is an archive to extract + * @param isClass + * Whether an entry is a class file + * @param recurse + * Whether to unpack recursively + * @return + * The copied class files in destDir + */ + def extractClassesInPackageLayout( + src: File, + destDir: File, + isClass: Entry => Boolean, + isArchive: Entry => Boolean, + recurse: Boolean + ): List[ClassFile] = + File + .temporaryDirectory("extract-classes-") + .apply(tmpDir => + extractClassesToTmp(src, tmpDir, isArchive, isClass, recurse: Boolean).iterator + .flatMap(_.copyToPackageLayoutIn(destDir)) + .toList + ) end ProgramHandlingUtil diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/JsSrc2Cpg.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/JsSrc2Cpg.scala index 1fcf26c3..b4fe6d70 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/JsSrc2Cpg.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/JsSrc2Cpg.scala @@ -32,52 +32,52 @@ import scala.util.Try class JsSrc2Cpg extends X2CpgFrontend[Config]: - private val report: Report = new Report() + private val report: Report = new Report() - def createCpg(config: Config): Try[Cpg] = - withNewEmptyCpg(config.outputPath, config) { (cpg, config) => - File.usingTemporaryDirectory("jssrc2cpgOut") { tmpDir => - val astGenResult = new AstGenRunner(config).execute(tmpDir) - val hash = HashUtil.sha256(astGenResult.parsedFiles.map { case (_, file) => - File(file).path - }) + def createCpg(config: Config): Try[Cpg] = + withNewEmptyCpg(config.outputPath, config) { (cpg, config) => + File.usingTemporaryDirectory("jssrc2cpgOut") { tmpDir => + val astGenResult = new AstGenRunner(config).execute(tmpDir) + val hash = HashUtil.sha256(astGenResult.parsedFiles.map { case (_, file) => + File(file).path + }) - val astCreationPass = - new AstCreationPass(cpg, astGenResult, config, report)(config.schemaValidation) - astCreationPass.createAndApply() + val astCreationPass = + new AstCreationPass(cpg, astGenResult, config, report)(config.schemaValidation) + astCreationPass.createAndApply() - new TypeNodePass(astCreationPass.allUsedTypes(), cpg).createAndApply() - new JsMetaDataPass(cpg, hash, config.inputPath).createAndApply() - new BuiltinTypesPass(cpg).createAndApply() - new DependenciesPass(cpg, config).createAndApply() - new ConfigPass(cpg, config, report).createAndApply() - new PrivateKeyFilePass(cpg, config, report).createAndApply() - new ImportsPass(cpg).createAndApply() + new TypeNodePass(astCreationPass.allUsedTypes(), cpg).createAndApply() + new JsMetaDataPass(cpg, hash, config.inputPath).createAndApply() + new BuiltinTypesPass(cpg).createAndApply() + new DependenciesPass(cpg, config).createAndApply() + new ConfigPass(cpg, config, report).createAndApply() + new PrivateKeyFilePass(cpg, config, report).createAndApply() + new ImportsPass(cpg).createAndApply() - report.print() - } - } + report.print() + } + } - // This method is intended for internal use only and may be removed at any time. - def createCpgWithAllOverlays(config: Config): Try[Cpg] = - val maybeCpg = createCpgWithOverlays(config) - maybeCpg.map { cpg => - new OssDataFlow(new OssDataFlowOptions()).run(new LayerCreatorContext(cpg)) - postProcessingPasses(cpg, Option(config)).foreach(_.createAndApply()) - cpg - } + // This method is intended for internal use only and may be removed at any time. + def createCpgWithAllOverlays(config: Config): Try[Cpg] = + val maybeCpg = createCpgWithOverlays(config) + maybeCpg.map { cpg => + new OssDataFlow(new OssDataFlowOptions()).run(new LayerCreatorContext(cpg)) + postProcessingPasses(cpg, Option(config)).foreach(_.createAndApply()) + cpg + } end JsSrc2Cpg object JsSrc2Cpg: - def postProcessingPasses(cpg: Cpg, config: Option[Config] = None): List[CpgPassBase] = - val typeRecoveryConfig = config - .map(c => XTypeRecoveryConfig(c.typePropagationIterations, !c.disableDummyTypes)) - .getOrElse(XTypeRecoveryConfig()) - List( - new JavaScriptInheritanceNamePass(cpg), - new ConstClosurePass(cpg), - new ImportResolverPass(cpg), - new JavaScriptTypeRecoveryPass(cpg, typeRecoveryConfig), - new JavaScriptTypeHintCallLinker(cpg) - ) + def postProcessingPasses(cpg: Cpg, config: Option[Config] = None): List[CpgPassBase] = + val typeRecoveryConfig = config + .map(c => XTypeRecoveryConfig(c.typePropagationIterations, !c.disableDummyTypes)) + .getOrElse(XTypeRecoveryConfig()) + List( + new JavaScriptInheritanceNamePass(cpg), + new ConstClosurePass(cpg), + new ImportResolverPass(cpg), + new JavaScriptTypeRecoveryPass(cpg, typeRecoveryConfig), + new JavaScriptTypeHintCallLinker(cpg) + ) diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/Main.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/Main.scala index 09ac2d94..0af7c60f 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/Main.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/Main.scala @@ -11,29 +11,29 @@ import java.nio.file.Paths final case class Config(tsTypes: Boolean = true) extends X2CpgConfig[Config] with TypeRecoveryParserConfig[Config]: - def withTsTypes(value: Boolean): Config = - copy(tsTypes = value).withInheritedFields(this) + def withTsTypes(value: Boolean): Config = + copy(tsTypes = value).withInheritedFields(this) object Frontend: - implicit val defaultConfig: Config = Config() - - val cmdLineParser: OParser[Unit, Config] = - val builder = OParser.builder[Config] - import builder.* - OParser.sequence( - programName("jssrc2cpg"), - opt[Unit]("no-tsTypes") - .hidden() - .action((_, c) => c.withTsTypes(false)) - .text("disable generation of types via Typescript"), - XTypeRecovery.parserOptions - ) + implicit val defaultConfig: Config = Config() + + val cmdLineParser: OParser[Unit, Config] = + val builder = OParser.builder[Config] + import builder.* + OParser.sequence( + programName("jssrc2cpg"), + opt[Unit]("no-tsTypes") + .hidden() + .action((_, c) => c.withTsTypes(false)) + .text("disable generation of types via Typescript"), + XTypeRecovery.parserOptions + ) object Main extends X2CpgMain(cmdLineParser, new JsSrc2Cpg()): - def run(config: Config, jssrc2cpg: JsSrc2Cpg): Unit = - val absPath = Paths.get(config.inputPath).toAbsolutePath.toString - if Environment.pathExists(absPath) then - jssrc2cpg.run(config.withInputPath(absPath)) - else - System.exit(1) + def run(config: Config, jssrc2cpg: JsSrc2Cpg): Unit = + val absPath = Paths.get(config.inputPath).toAbsolutePath.toString + if Environment.pathExists(absPath) then + jssrc2cpg.run(config.withInputPath(absPath)) + else + System.exit(1) diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstCreator.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstCreator.scala index 68109d7e..77660030 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstCreator.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstCreator.scala @@ -46,233 +46,233 @@ class AstCreator( with AstCreatorHelper with X2CpgAstNodeBuilder[BabelNodeInfo, AstCreator]: - protected val logger: Logger = LoggerFactory.getLogger(classOf[AstCreator]) + protected val logger: Logger = LoggerFactory.getLogger(classOf[AstCreator]) - protected val scope = new Scope() + protected val scope = new Scope() - // TypeDecls with their bindings (with their refs) for lambdas and methods are not put in the AST - // where the respective nodes are defined. Instead we put them under the parent TYPE_DECL in which they are defined. - // To achieve this we need this extra stack. - protected val methodAstParentStack = new Stack[NewNode]() - protected val typeRefIdStack = new Stack[NewTypeRef] - protected val dynamicInstanceTypeStack = new Stack[String] - protected val localAstParentStack = new Stack[NewBlock]() - protected val rootTypeDecl = new Stack[NewTypeDecl]() - protected val typeFullNameToPostfix = mutable.HashMap.empty[String, Int] - protected val functionNodeToNameAndFullName = - mutable.HashMap.empty[BabelNodeInfo, (String, String)] - protected val usedVariableNames = mutable.HashMap.empty[String, Int] - protected val seenAliasTypes = mutable.HashSet.empty[NewTypeDecl] - protected val functionFullNames = mutable.HashSet.empty[String] + // TypeDecls with their bindings (with their refs) for lambdas and methods are not put in the AST + // where the respective nodes are defined. Instead we put them under the parent TYPE_DECL in which they are defined. + // To achieve this we need this extra stack. + protected val methodAstParentStack = new Stack[NewNode]() + protected val typeRefIdStack = new Stack[NewTypeRef] + protected val dynamicInstanceTypeStack = new Stack[String] + protected val localAstParentStack = new Stack[NewBlock]() + protected val rootTypeDecl = new Stack[NewTypeDecl]() + protected val typeFullNameToPostfix = mutable.HashMap.empty[String, Int] + protected val functionNodeToNameAndFullName = + mutable.HashMap.empty[BabelNodeInfo, (String, String)] + protected val usedVariableNames = mutable.HashMap.empty[String, Int] + protected val seenAliasTypes = mutable.HashSet.empty[NewTypeDecl] + protected val functionFullNames = mutable.HashSet.empty[String] - // we track line and column numbers manually because astgen / @babel-parser sometimes - // fails to deliver them at all - strange, but this even happens with its latest version - protected val (positionToLineNumberMapping, positionToFirstPositionInLineMapping) = - positionLookupTables(parserResult.fileContent) + // we track line and column numbers manually because astgen / @babel-parser sometimes + // fails to deliver them at all - strange, but this even happens with its latest version + protected val (positionToLineNumberMapping, positionToFirstPositionInLineMapping) = + positionLookupTables(parserResult.fileContent) - override def createAst(): DiffGraphBuilder = - val fileNode = NewFile().name(parserResult.filename).order(1) - val namespaceBlock = globalNamespaceBlock() - methodAstParentStack.push(namespaceBlock) - val ast = Ast(fileNode).withChild(Ast(namespaceBlock).withChild(createProgramMethod())) - Ast.storeInDiffGraph(ast, diffGraph) - createVariableReferenceLinks() - diffGraph + override def createAst(): DiffGraphBuilder = + val fileNode = NewFile().name(parserResult.filename).order(1) + val namespaceBlock = globalNamespaceBlock() + methodAstParentStack.push(namespaceBlock) + val ast = Ast(fileNode).withChild(Ast(namespaceBlock).withChild(createProgramMethod())) + Ast.storeInDiffGraph(ast, diffGraph) + createVariableReferenceLinks() + diffGraph - private def createProgramMethod(): Ast = - val path = parserResult.filename - val astNodeInfo = createBabelNodeInfo(parserResult.json("ast")) - val lineNumber = astNodeInfo.lineNumber - val columnNumber = astNodeInfo.columnNumber - val lineNumberEnd = astNodeInfo.lineNumberEnd - val columnNumberEnd = astNodeInfo.columnNumberEnd - val name = ":program" - val fullName = s"$path:$name" + private def createProgramMethod(): Ast = + val path = parserResult.filename + val astNodeInfo = createBabelNodeInfo(parserResult.json("ast")) + val lineNumber = astNodeInfo.lineNumber + val columnNumber = astNodeInfo.columnNumber + val lineNumberEnd = astNodeInfo.lineNumberEnd + val columnNumberEnd = astNodeInfo.columnNumberEnd + val name = ":program" + val fullName = s"$path:$name" - val programMethod = - NewMethod() - .order(1) - .name(name) - .code(name) - .fullName(fullName) - .filename(path) - .lineNumber(lineNumber) - .lineNumberEnd(lineNumberEnd) - .columnNumber(columnNumber) - .columnNumberEnd(columnNumberEnd) - .astParentType(NodeTypes.TYPE_DECL) - .astParentFullName(fullName) + val programMethod = + NewMethod() + .order(1) + .name(name) + .code(name) + .fullName(fullName) + .filename(path) + .lineNumber(lineNumber) + .lineNumberEnd(lineNumberEnd) + .columnNumber(columnNumber) + .columnNumberEnd(columnNumberEnd) + .astParentType(NodeTypes.TYPE_DECL) + .astParentFullName(fullName) - val functionTypeAndTypeDeclAst = - createFunctionTypeAndTypeDeclAst( - astNodeInfo, - programMethod, - methodAstParentStack.head, - name, - fullName, - path - ) - rootTypeDecl.push(functionTypeAndTypeDeclAst.nodes.head.asInstanceOf[NewTypeDecl]) + val functionTypeAndTypeDeclAst = + createFunctionTypeAndTypeDeclAst( + astNodeInfo, + programMethod, + methodAstParentStack.head, + name, + fullName, + path + ) + rootTypeDecl.push(functionTypeAndTypeDeclAst.nodes.head.asInstanceOf[NewTypeDecl]) - methodAstParentStack.push(programMethod) + methodAstParentStack.push(programMethod) - val blockNode = NewBlock().typeFullName(Defines.Any) + val blockNode = NewBlock().typeFullName(Defines.Any) - scope.pushNewMethodScope(fullName, name, blockNode, None) - localAstParentStack.push(blockNode) + scope.pushNewMethodScope(fullName, name, blockNode, None) + localAstParentStack.push(blockNode) - val thisParam = - parameterInNode(astNodeInfo, "this", "this", 0, false, EvaluationStrategies.BY_VALUE) - .dynamicTypeHintFullName(typeHintForThisExpression()) - scope.addVariable("this", thisParam, MethodScope) + val thisParam = + parameterInNode(astNodeInfo, "this", "this", 0, false, EvaluationStrategies.BY_VALUE) + .dynamicTypeHintFullName(typeHintForThisExpression()) + scope.addVariable("this", thisParam, MethodScope) - val methodChildren = astsForFile(astNodeInfo) - setArgumentIndices(methodChildren) + val methodChildren = astsForFile(astNodeInfo) + setArgumentIndices(methodChildren) - val methodReturn = newMethodReturnNode(Defines.Any, line = None, column = None) + val methodReturn = newMethodReturnNode(Defines.Any, line = None, column = None) - localAstParentStack.pop() - scope.popScope() - methodAstParentStack.pop() + localAstParentStack.pop() + scope.popScope() + methodAstParentStack.pop() - functionTypeAndTypeDeclAst.withChild( - methodAst( - programMethod, - List(Ast(thisParam)), - blockAst(blockNode, methodChildren), - methodReturn - ) - ) - end createProgramMethod + functionTypeAndTypeDeclAst.withChild( + methodAst( + programMethod, + List(Ast(thisParam)), + blockAst(blockNode, methodChildren), + methodReturn + ) + ) + end createProgramMethod - protected def astForNode(json: Value): Ast = - val nodeInfo = createBabelNodeInfo(json) - nodeInfo.node match - case ClassDeclaration => astForClass(nodeInfo, shouldCreateAssignmentCall = true) - case DeclareClass => astForClass(nodeInfo, shouldCreateAssignmentCall = true) - case ClassExpression => astForClass(nodeInfo) - case TSInterfaceDeclaration => astForInterface(nodeInfo) - case TSModuleDeclaration => astForModule(nodeInfo) - case TSExportAssignment => astForExportAssignment(nodeInfo) - case ExportNamedDeclaration => astForExportNamedDeclaration(nodeInfo) - case ExportDefaultDeclaration => astForExportDefaultDeclaration(nodeInfo) - case ExportAllDeclaration => astForExportAllDeclaration(nodeInfo) - case ImportDeclaration => astForImportDeclaration(nodeInfo) - case FunctionDeclaration => astForFunctionDeclaration(nodeInfo) - case TSDeclareFunction => astForTSDeclareFunction(nodeInfo) - case VariableDeclaration => astForVariableDeclaration(nodeInfo) - case ArrowFunctionExpression => astForFunctionDeclaration(nodeInfo) - case FunctionExpression => astForFunctionDeclaration(nodeInfo) - case TSEnumDeclaration => astForEnum(nodeInfo) - case DeclareTypeAlias => astForTypeAlias(nodeInfo) - case TypeAlias => astForTypeAlias(nodeInfo) - case TypeCastExpression => astForCastExpression(nodeInfo) - case TSTypeAssertion => astForCastExpression(nodeInfo) - case TSTypeCastExpression => astForCastExpression(nodeInfo) - case TSTypeAliasDeclaration => astForTypeAlias(nodeInfo) - case NewExpression => astForNewExpression(nodeInfo) - case ThisExpression => astForThisExpression(nodeInfo) - case MemberExpression => astForMemberExpression(nodeInfo) - case OptionalMemberExpression => astForMemberExpression(nodeInfo) - case MetaProperty => astForMetaProperty(nodeInfo) - case CallExpression => astForCallExpression(nodeInfo) - case OptionalCallExpression => astForCallExpression(nodeInfo) - case SequenceExpression => astForSequenceExpression(nodeInfo) - case AssignmentExpression => astForAssignmentExpression(nodeInfo) - case AssignmentPattern => astForAssignmentExpression(nodeInfo) - case BinaryExpression => astForBinaryExpression(nodeInfo) - case LogicalExpression => astForLogicalExpression(nodeInfo) - case TSAsExpression => astForCastExpression(nodeInfo) - case UpdateExpression => astForUpdateExpression(nodeInfo) - case UnaryExpression => astForUnaryExpression(nodeInfo) - case ArrayExpression => astForArrayExpression(nodeInfo) - case AwaitExpression => astForAwaitExpression(nodeInfo) - case ConditionalExpression => astForConditionalExpression(nodeInfo) - case TaggedTemplateExpression => astForTemplateExpression(nodeInfo) - case ObjectExpression => astForObjectExpression(nodeInfo) - case TSNonNullExpression => astForTSNonNullExpression(nodeInfo) - case YieldExpression => astForReturnStatement(nodeInfo) - case ExpressionStatement => astForExpressionStatement(nodeInfo) - case IfStatement => astForIfStatement(nodeInfo) - case BlockStatement => astForBlockStatement(nodeInfo) - case ReturnStatement => astForReturnStatement(nodeInfo) - case TryStatement => astForTryStatement(nodeInfo) - case ForStatement => astForForStatement(nodeInfo) - case WhileStatement => astForWhileStatement(nodeInfo) - case DoWhileStatement => astForDoWhileStatement(nodeInfo) - case SwitchStatement => astForSwitchStatement(nodeInfo) - case BreakStatement => astForBreakStatement(nodeInfo) - case ContinueStatement => astForContinueStatement(nodeInfo) - case LabeledStatement => astForLabeledStatement(nodeInfo) - case ThrowStatement => astForThrowStatement(nodeInfo) - case ForInStatement => astForInOfStatement(nodeInfo) - case ForOfStatement => astForInOfStatement(nodeInfo) - case ObjectPattern => astForObjectExpression(nodeInfo) - case ArrayPattern => astForArrayExpression(nodeInfo) - case Identifier => astForIdentifier(nodeInfo) - case PrivateName => astForPrivateName(nodeInfo) - case Super => astForSuperKeyword(nodeInfo) - case Import => astForImportKeyword(nodeInfo) - case TSImportEqualsDeclaration => astForTSImportEqualsDeclaration(nodeInfo) - case StringLiteral => astForStringLiteral(nodeInfo) - case NumericLiteral => astForNumericLiteral(nodeInfo) - case NumberLiteral => astForNumberLiteral(nodeInfo) - case DecimalLiteral => astForDecimalLiteral(nodeInfo) - case NullLiteral => astForNullLiteral(nodeInfo) - case BooleanLiteral => astForBooleanLiteral(nodeInfo) - case RegExpLiteral => astForRegExpLiteral(nodeInfo) - case RegexLiteral => astForRegexLiteral(nodeInfo) - case BigIntLiteral => astForBigIntLiteral(nodeInfo) - case TemplateLiteral => astForTemplateLiteral(nodeInfo) - case TemplateElement => astForTemplateElement(nodeInfo) - case SpreadElement => astForSpreadOrRestElement(nodeInfo) - case TSSatisfiesExpression => astForTSSatisfiesExpression(nodeInfo) - case JSXElement => astForJsxElement(nodeInfo) - case JSXOpeningElement => astForJsxOpeningElement(nodeInfo) - case JSXClosingElement => astForJsxClosingElement(nodeInfo) - case JSXText => astForJsxText(nodeInfo) - case JSXExpressionContainer => astForJsxExprContainer(nodeInfo) - case JSXSpreadChild => astForJsxExprContainer(nodeInfo) - case JSXSpreadAttribute => astForJsxSpreadAttribute(nodeInfo) - case JSXFragment => astForJsxFragment(nodeInfo) - case JSXAttribute => astForJsxAttribute(nodeInfo) - case WithStatement => astForWithStatement(nodeInfo) - case EmptyStatement => Ast() - case DebuggerStatement => Ast() - case _ => notHandledYet(nodeInfo) - end match - end astForNode + protected def astForNode(json: Value): Ast = + val nodeInfo = createBabelNodeInfo(json) + nodeInfo.node match + case ClassDeclaration => astForClass(nodeInfo, shouldCreateAssignmentCall = true) + case DeclareClass => astForClass(nodeInfo, shouldCreateAssignmentCall = true) + case ClassExpression => astForClass(nodeInfo) + case TSInterfaceDeclaration => astForInterface(nodeInfo) + case TSModuleDeclaration => astForModule(nodeInfo) + case TSExportAssignment => astForExportAssignment(nodeInfo) + case ExportNamedDeclaration => astForExportNamedDeclaration(nodeInfo) + case ExportDefaultDeclaration => astForExportDefaultDeclaration(nodeInfo) + case ExportAllDeclaration => astForExportAllDeclaration(nodeInfo) + case ImportDeclaration => astForImportDeclaration(nodeInfo) + case FunctionDeclaration => astForFunctionDeclaration(nodeInfo) + case TSDeclareFunction => astForTSDeclareFunction(nodeInfo) + case VariableDeclaration => astForVariableDeclaration(nodeInfo) + case ArrowFunctionExpression => astForFunctionDeclaration(nodeInfo) + case FunctionExpression => astForFunctionDeclaration(nodeInfo) + case TSEnumDeclaration => astForEnum(nodeInfo) + case DeclareTypeAlias => astForTypeAlias(nodeInfo) + case TypeAlias => astForTypeAlias(nodeInfo) + case TypeCastExpression => astForCastExpression(nodeInfo) + case TSTypeAssertion => astForCastExpression(nodeInfo) + case TSTypeCastExpression => astForCastExpression(nodeInfo) + case TSTypeAliasDeclaration => astForTypeAlias(nodeInfo) + case NewExpression => astForNewExpression(nodeInfo) + case ThisExpression => astForThisExpression(nodeInfo) + case MemberExpression => astForMemberExpression(nodeInfo) + case OptionalMemberExpression => astForMemberExpression(nodeInfo) + case MetaProperty => astForMetaProperty(nodeInfo) + case CallExpression => astForCallExpression(nodeInfo) + case OptionalCallExpression => astForCallExpression(nodeInfo) + case SequenceExpression => astForSequenceExpression(nodeInfo) + case AssignmentExpression => astForAssignmentExpression(nodeInfo) + case AssignmentPattern => astForAssignmentExpression(nodeInfo) + case BinaryExpression => astForBinaryExpression(nodeInfo) + case LogicalExpression => astForLogicalExpression(nodeInfo) + case TSAsExpression => astForCastExpression(nodeInfo) + case UpdateExpression => astForUpdateExpression(nodeInfo) + case UnaryExpression => astForUnaryExpression(nodeInfo) + case ArrayExpression => astForArrayExpression(nodeInfo) + case AwaitExpression => astForAwaitExpression(nodeInfo) + case ConditionalExpression => astForConditionalExpression(nodeInfo) + case TaggedTemplateExpression => astForTemplateExpression(nodeInfo) + case ObjectExpression => astForObjectExpression(nodeInfo) + case TSNonNullExpression => astForTSNonNullExpression(nodeInfo) + case YieldExpression => astForReturnStatement(nodeInfo) + case ExpressionStatement => astForExpressionStatement(nodeInfo) + case IfStatement => astForIfStatement(nodeInfo) + case BlockStatement => astForBlockStatement(nodeInfo) + case ReturnStatement => astForReturnStatement(nodeInfo) + case TryStatement => astForTryStatement(nodeInfo) + case ForStatement => astForForStatement(nodeInfo) + case WhileStatement => astForWhileStatement(nodeInfo) + case DoWhileStatement => astForDoWhileStatement(nodeInfo) + case SwitchStatement => astForSwitchStatement(nodeInfo) + case BreakStatement => astForBreakStatement(nodeInfo) + case ContinueStatement => astForContinueStatement(nodeInfo) + case LabeledStatement => astForLabeledStatement(nodeInfo) + case ThrowStatement => astForThrowStatement(nodeInfo) + case ForInStatement => astForInOfStatement(nodeInfo) + case ForOfStatement => astForInOfStatement(nodeInfo) + case ObjectPattern => astForObjectExpression(nodeInfo) + case ArrayPattern => astForArrayExpression(nodeInfo) + case Identifier => astForIdentifier(nodeInfo) + case PrivateName => astForPrivateName(nodeInfo) + case Super => astForSuperKeyword(nodeInfo) + case Import => astForImportKeyword(nodeInfo) + case TSImportEqualsDeclaration => astForTSImportEqualsDeclaration(nodeInfo) + case StringLiteral => astForStringLiteral(nodeInfo) + case NumericLiteral => astForNumericLiteral(nodeInfo) + case NumberLiteral => astForNumberLiteral(nodeInfo) + case DecimalLiteral => astForDecimalLiteral(nodeInfo) + case NullLiteral => astForNullLiteral(nodeInfo) + case BooleanLiteral => astForBooleanLiteral(nodeInfo) + case RegExpLiteral => astForRegExpLiteral(nodeInfo) + case RegexLiteral => astForRegexLiteral(nodeInfo) + case BigIntLiteral => astForBigIntLiteral(nodeInfo) + case TemplateLiteral => astForTemplateLiteral(nodeInfo) + case TemplateElement => astForTemplateElement(nodeInfo) + case SpreadElement => astForSpreadOrRestElement(nodeInfo) + case TSSatisfiesExpression => astForTSSatisfiesExpression(nodeInfo) + case JSXElement => astForJsxElement(nodeInfo) + case JSXOpeningElement => astForJsxOpeningElement(nodeInfo) + case JSXClosingElement => astForJsxClosingElement(nodeInfo) + case JSXText => astForJsxText(nodeInfo) + case JSXExpressionContainer => astForJsxExprContainer(nodeInfo) + case JSXSpreadChild => astForJsxExprContainer(nodeInfo) + case JSXSpreadAttribute => astForJsxSpreadAttribute(nodeInfo) + case JSXFragment => astForJsxFragment(nodeInfo) + case JSXAttribute => astForJsxAttribute(nodeInfo) + case WithStatement => astForWithStatement(nodeInfo) + case EmptyStatement => Ast() + case DebuggerStatement => Ast() + case _ => notHandledYet(nodeInfo) + end match + end astForNode - protected def astForNodeWithFunctionReference(json: Value): Ast = - val nodeInfo = createBabelNodeInfo(json) - nodeInfo.node match - case _: FunctionLike => - astForFunctionDeclaration(nodeInfo, shouldCreateFunctionReference = true) - case _ => astForNode(json) + protected def astForNodeWithFunctionReference(json: Value): Ast = + val nodeInfo = createBabelNodeInfo(json) + nodeInfo.node match + case _: FunctionLike => + astForFunctionDeclaration(nodeInfo, shouldCreateFunctionReference = true) + case _ => astForNode(json) - protected def astForNodeWithFunctionReferenceAndCall(json: Value): Ast = - val nodeInfo = createBabelNodeInfo(json) - nodeInfo.node match - case _: FunctionLike => - astForFunctionDeclaration( - nodeInfo, - shouldCreateFunctionReference = true, - shouldCreateAssignmentCall = true - ) - case _ => astForNode(json) + protected def astForNodeWithFunctionReferenceAndCall(json: Value): Ast = + val nodeInfo = createBabelNodeInfo(json) + nodeInfo.node match + case _: FunctionLike => + astForFunctionDeclaration( + nodeInfo, + shouldCreateFunctionReference = true, + shouldCreateAssignmentCall = true + ) + case _ => astForNode(json) - protected def astForNodes(jsons: List[Value]): List[Ast] = - jsons.map(astForNodeWithFunctionReference) + protected def astForNodes(jsons: List[Value]): List[Ast] = + jsons.map(astForNodeWithFunctionReference) - private def astsForFile(file: BabelNodeInfo): List[Ast] = - astsForProgram(createBabelNodeInfo(file.json("program"))) + private def astsForFile(file: BabelNodeInfo): List[Ast] = + astsForProgram(createBabelNodeInfo(file.json("program"))) - private def astsForProgram(program: BabelNodeInfo): List[Ast] = - createBlockStatementAsts(program.json("body")) + private def astsForProgram(program: BabelNodeInfo): List[Ast] = + createBlockStatementAsts(program.json("body")) - protected def line(node: BabelNodeInfo): Option[Integer] = node.lineNumber - protected def column(node: BabelNodeInfo): Option[Integer] = node.columnNumber - protected def lineEnd(node: BabelNodeInfo): Option[Integer] = node.lineNumberEnd - protected def columnEnd(node: BabelNodeInfo): Option[Integer] = node.columnNumberEnd - protected def code(node: BabelNodeInfo): String = node.code + protected def line(node: BabelNodeInfo): Option[Integer] = node.lineNumber + protected def column(node: BabelNodeInfo): Option[Integer] = node.columnNumber + protected def lineEnd(node: BabelNodeInfo): Option[Integer] = node.lineNumberEnd + protected def columnEnd(node: BabelNodeInfo): Option[Integer] = node.columnNumberEnd + protected def code(node: BabelNodeInfo): String = node.code end AstCreator diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstCreatorHelper.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstCreatorHelper.scala index d7ae83ae..de9ded20 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstCreatorHelper.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstCreatorHelper.scala @@ -32,298 +32,298 @@ import scala.util.Success import scala.util.Try trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode): - this: AstCreator => - - // maximum length of code fields in number of characters - private val MaxCodeLength: Int = 1000 - private val MinCodeLength: Int = 50 - - protected def createBabelNodeInfo(json: Value): BabelNodeInfo = - val c = shortenCode(code(json)) - val ln = line(json) - val cn = column(json) - val lnEnd = lineEnd(json) - val cnEnd = columnEnd(json) - val node = nodeType(json) - BabelNodeInfo(node, json, c, ln, cn, lnEnd, cnEnd) - - protected def notHandledYet(node: BabelNodeInfo): Ast = - val text = - s"""Node type '${node.node}' not handled yet! + this: AstCreator => + + // maximum length of code fields in number of characters + private val MaxCodeLength: Int = 1000 + private val MinCodeLength: Int = 50 + + protected def createBabelNodeInfo(json: Value): BabelNodeInfo = + val c = shortenCode(code(json)) + val ln = line(json) + val cn = column(json) + val lnEnd = lineEnd(json) + val cnEnd = columnEnd(json) + val node = nodeType(json) + BabelNodeInfo(node, json, c, ln, cn, lnEnd, cnEnd) + + protected def notHandledYet(node: BabelNodeInfo): Ast = + val text = + s"""Node type '${node.node}' not handled yet! | Code: '${shortenCode(node.code, length = 50)}' | File: '${parserResult.fullPath}' | Line: ${node.lineNumber.getOrElse(-1)} | Column: ${node.columnNumber.getOrElse(-1)} | """.stripMargin - logger.debug(text) - Ast(unknownNode(node, node.code)) - - protected def registerType(typeName: String, typeFullName: String): Unit = - if usedTypes.containsKey((typeName, typeName)) && typeName != typeFullName then - usedTypes.put((typeName, typeFullName), true) - usedTypes.remove((typeName, typeName)) - else if !usedTypes.keys().asScala.exists { case (tpn, _) => tpn == typeName } then - usedTypes.putIfAbsent((typeName, typeFullName), true) - - private def nodeType(node: Value): BabelNode = fromString(node("type").str) - - protected def codeForNodes(nodes: Seq[NewNode]): Option[String] = nodes.collectFirst { - case id: NewIdentifier => id.name.replace("...", "") - case clazz: NewTypeRef => clazz.code.stripPrefix("class ") - } - - protected def nameForBabelNodeInfo( - nodeInfo: BabelNodeInfo, - defaultName: Option[String] - ): String = - defaultName - .orElse(codeForBabelNodeInfo(nodeInfo).headOption) - .getOrElse { - val tmpName = generateUnusedVariableName(usedVariableNames, "_tmp") - val localNode = newLocalNode(tmpName, Defines.Any).order(0) - diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) - tmpName - } - - protected def generateUnusedVariableName( - usedVariableNames: mutable.HashMap[String, Int], - variableName: String - ): String = - val counter = usedVariableNames.get(variableName).map(_ + 1).getOrElse(0) - val currentVariableName = s"${variableName}_$counter" - usedVariableNames.put(variableName, counter) - currentVariableName - - protected def code(node: Value): String = - val startIndex = start(node).getOrElse(0) - val endIndex = Math.min(end(node).getOrElse(0), parserResult.fileContent.length) - parserResult.fileContent.substring(startIndex, endIndex).trim - - private def shortenCode(code: String, length: Int = MaxCodeLength): String = - StringUtils.abbreviate(code, math.max(MinCodeLength, length)) - - protected def hasKey(node: Value, key: String): Boolean = Try(node(key)).isSuccess - - protected def safeStr(node: Value, key: String): Option[String] = - if hasKey(node, key) then Try(node(key).str).toOption else None - - protected def safeBool(node: Value, key: String): Option[Boolean] = - if hasKey(node, key) then Try(node(key).bool).toOption else None - - protected def safeObj( - node: Value, - key: String - ): Option[upickle.core.LinkedHashMap[String, Value]] = Try( - node(key).obj - ) match - case Success(value) if value.nonEmpty => Option(value) - case _ => None - - private def start(node: Value): Option[Int] = Try(node("start").num.toInt).toOption - - private def end(node: Value): Option[Int] = Try(node("end").num.toInt).toOption - - protected def pos(node: Value): Option[Int] = Try(node("start").num.toInt).toOption - - protected def line(node: Value): Option[Integer] = start(node).map(getLineOfSource) - - protected def lineEnd(node: Value): Option[Integer] = end(node).map(getLineOfSource) - - protected def column(node: Value): Option[Integer] = start(node).map(getColumnOfSource) - - protected def columnEnd(node: Value): Option[Integer] = end(node).map(getColumnOfSource) - - // Returns the line number for a given position in the source. - private def getLineOfSource(position: Int): Int = - val (_, lineNumber) = positionToLineNumberMapping.minAfter(position).get - lineNumber - - // Returns the column number for a given position in the source. - private def getColumnOfSource(position: Int): Int = - val (_, firstPositionInLine) = positionToFirstPositionInLineMapping.minAfter(position).get - position - firstPositionInLine - - protected def positionLookupTables(source: String): (SortedMap[Int, Int], SortedMap[Int, Int]) = - val positionToLineNumber, positionToFirstPositionInLine = mutable.TreeMap.empty[Int, Int] - val data = source.toCharArray - var lineNumber = 1 - var firstPositionInLine = 0 - var position = 0 - while position < data.length do - val isNewLine = data(position) == '\n' - if isNewLine then - positionToLineNumber.put(position, lineNumber) - lineNumber += 1 - positionToFirstPositionInLine.put(position, firstPositionInLine) - firstPositionInLine = position + 1 - position += 1 + logger.debug(text) + Ast(unknownNode(node, node.code)) + + protected def registerType(typeName: String, typeFullName: String): Unit = + if usedTypes.containsKey((typeName, typeName)) && typeName != typeFullName then + usedTypes.put((typeName, typeFullName), true) + usedTypes.remove((typeName, typeName)) + else if !usedTypes.keys().asScala.exists { case (tpn, _) => tpn == typeName } then + usedTypes.putIfAbsent((typeName, typeFullName), true) + + private def nodeType(node: Value): BabelNode = fromString(node("type").str) + + protected def codeForNodes(nodes: Seq[NewNode]): Option[String] = nodes.collectFirst { + case id: NewIdentifier => id.name.replace("...", "") + case clazz: NewTypeRef => clazz.code.stripPrefix("class ") + } + + protected def nameForBabelNodeInfo( + nodeInfo: BabelNodeInfo, + defaultName: Option[String] + ): String = + defaultName + .orElse(codeForBabelNodeInfo(nodeInfo).headOption) + .getOrElse { + val tmpName = generateUnusedVariableName(usedVariableNames, "_tmp") + val localNode = newLocalNode(tmpName, Defines.Any).order(0) + diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) + tmpName + } + + protected def generateUnusedVariableName( + usedVariableNames: mutable.HashMap[String, Int], + variableName: String + ): String = + val counter = usedVariableNames.get(variableName).map(_ + 1).getOrElse(0) + val currentVariableName = s"${variableName}_$counter" + usedVariableNames.put(variableName, counter) + currentVariableName + + protected def code(node: Value): String = + val startIndex = start(node).getOrElse(0) + val endIndex = Math.min(end(node).getOrElse(0), parserResult.fileContent.length) + parserResult.fileContent.substring(startIndex, endIndex).trim + + private def shortenCode(code: String, length: Int = MaxCodeLength): String = + StringUtils.abbreviate(code, math.max(MinCodeLength, length)) + + protected def hasKey(node: Value, key: String): Boolean = Try(node(key)).isSuccess + + protected def safeStr(node: Value, key: String): Option[String] = + if hasKey(node, key) then Try(node(key).str).toOption else None + + protected def safeBool(node: Value, key: String): Option[Boolean] = + if hasKey(node, key) then Try(node(key).bool).toOption else None + + protected def safeObj( + node: Value, + key: String + ): Option[upickle.core.LinkedHashMap[String, Value]] = Try( + node(key).obj + ) match + case Success(value) if value.nonEmpty => Option(value) + case _ => None + + private def start(node: Value): Option[Int] = Try(node("start").num.toInt).toOption + + private def end(node: Value): Option[Int] = Try(node("end").num.toInt).toOption + + protected def pos(node: Value): Option[Int] = Try(node("start").num.toInt).toOption + + protected def line(node: Value): Option[Integer] = start(node).map(getLineOfSource) + + protected def lineEnd(node: Value): Option[Integer] = end(node).map(getLineOfSource) + + protected def column(node: Value): Option[Integer] = start(node).map(getColumnOfSource) + + protected def columnEnd(node: Value): Option[Integer] = end(node).map(getColumnOfSource) + + // Returns the line number for a given position in the source. + private def getLineOfSource(position: Int): Int = + val (_, lineNumber) = positionToLineNumberMapping.minAfter(position).get + lineNumber + + // Returns the column number for a given position in the source. + private def getColumnOfSource(position: Int): Int = + val (_, firstPositionInLine) = positionToFirstPositionInLineMapping.minAfter(position).get + position - firstPositionInLine + + protected def positionLookupTables(source: String): (SortedMap[Int, Int], SortedMap[Int, Int]) = + val positionToLineNumber, positionToFirstPositionInLine = mutable.TreeMap.empty[Int, Int] + val data = source.toCharArray + var lineNumber = 1 + var firstPositionInLine = 0 + var position = 0 + while position < data.length do + val isNewLine = data(position) == '\n' + if isNewLine then positionToLineNumber.put(position, lineNumber) + lineNumber += 1 positionToFirstPositionInLine.put(position, firstPositionInLine) - - // for empty line at the end of each JS/TS file generated by BabelJsonParser: - positionToLineNumber.put(position + 1, lineNumber + 1) - positionToFirstPositionInLine.put(position + 1, 0) - (positionToLineNumber, positionToFirstPositionInLine) - end positionLookupTables - - private def computeScopePath(stack: Option[ScopeElement]): String = - new ScopeElementIterator(stack) - .to(Seq) - .reverse - .collect { case methodScopeElement: MethodScopeElement => methodScopeElement.name } - .mkString(":") - - private def isMethodOrGetSet(func: BabelNodeInfo): Boolean = - if hasKey(func.json, "kind") && !func.json("kind").isNull then - val t = func.json("kind").str - t == "method" || t == "get" || t == "set" - else false - - private def calcMethodName(func: BabelNodeInfo): String = func.node match - case TSCallSignatureDeclaration => "anonymous" - case TSConstructSignatureDeclaration => io.appthreat.x2cpg.Defines.ConstructorMethodName - case _ if isMethodOrGetSet(func) => - if hasKey(func.json("key"), "name") then func.json("key")("name").str - else code(func.json("key")) - case _ if safeStr(func.json, "kind").contains("constructor") => - io.appthreat.x2cpg.Defines.ConstructorMethodName - case _ if func.json("id").isNull => "anonymous" - case _ => func.json("id")("name").str - - protected def calcMethodNameAndFullName(func: BabelNodeInfo): (String, String) = - // functionNode.getName is not necessarily unique and thus the full name calculated based on the scope - // is not necessarily unique. Specifically we have this problem with lambda functions which are defined - // in the same scope. - functionNodeToNameAndFullName.get(func) match - case Some(nameAndFullName) => nameAndFullName - case None => - val intendedName = calcMethodName(func) - val fullNamePrefix = - s"${parserResult.filename}:${computeScopePath(scope.getScopeHead)}:" - var name = intendedName - var fullName = "" - var isUnique = false - var i = 1 - while !isUnique do - fullName = s"$fullNamePrefix$name" - if functionFullNames.contains(fullName) then - name = s"$intendedName$i" - i += 1 - else - isUnique = true - functionFullNames.add(fullName) - functionNodeToNameAndFullName(func) = (name, fullName) - (name, fullName) - - protected def stripQuotes(str: String): String = str - .stripPrefix("\"") - .stripSuffix("\"") - .stripPrefix("'") - .stripSuffix("'") - .stripPrefix("`") - .stripSuffix("`") - - /** In JS it is possible to create anonymous classes. We have to handle this here. - */ - private def calcTypeName(classNode: BabelNodeInfo): String = - if hasKey(classNode.json, "id") && !classNode.json("id").isNull then - code(classNode.json("id")) - else "_anon_cdecl" - - protected def calcTypeNameAndFullName( - classNode: BabelNodeInfo, - preCalculatedName: Option[String] = None - ): (String, String) = - val name = preCalculatedName.getOrElse(calcTypeName(classNode)) - val fullNamePrefix = s"${parserResult.filename}:${computeScopePath(scope.getScopeHead)}:" - val intendedFullName = s"$fullNamePrefix$name" - val postfix = typeFullNameToPostfix.getOrElse(intendedFullName, 0) - val resultingFullName = - if postfix == 0 then intendedFullName - else s"$intendedFullName$postfix" - typeFullNameToPostfix.put(intendedFullName, postfix + 1) - (name, resultingFullName) - - protected def createVariableReferenceLinks(): Unit = - val resolvedReferenceIt = scope.resolve(createMethodLocalForUnresolvedReference) - val capturedLocals = mutable.HashMap.empty[String, NewNode] - - resolvedReferenceIt.foreach { case ResolvedReference(variableNodeId, origin) => - var currentScope = origin.stack - var currentReference = origin.referenceNode - var nextReference: NewNode = null - - var done = false - while !done do - val localOrCapturedLocalNodeOption = - if currentScope.get.nameToVariableNode.contains(origin.variableName) then - done = true - Option(variableNodeId) - else - currentScope.flatMap { - case methodScope: MethodScopeElement - if methodScope.scopeNode.isInstanceOf[NewTypeDecl] || methodScope.scopeNode - .isInstanceOf[NewNamespaceBlock] => - currentScope = - Option(Scope.getEnclosingMethodScopeElement(currentScope)) - None - case methodScope: MethodScopeElement => - // We have reached a MethodScope and still did not find a local variable to link to. - // For all non local references the CPG format does not allow us to link - // directly. Instead we need to create a fake local variable in method - // scope and link to this local which itself carries the information - // that it is a captured variable. This needs to be done for each - // method scope until we reach the originating scope. - val closureBindingIdProperty = - s"${methodScope.methodFullName}:${origin.variableName}" - capturedLocals.updateWith(closureBindingIdProperty) { - case None => - val methodScopeNode = methodScope.scopeNode - val localNode = - newLocalNode( - origin.variableName, - Defines.Any, - Option(closureBindingIdProperty) - ).order(0) - diffGraph.addEdge(methodScopeNode, localNode, EdgeTypes.AST) - val closureBindingNode = newClosureBindingNode( - closureBindingIdProperty, - origin.variableName, - EvaluationStrategies.BY_REFERENCE - ) - methodScope.capturingRefId.foreach(ref => - diffGraph.addEdge( - ref, - closureBindingNode, - EdgeTypes.CAPTURE - ) - ) - nextReference = closureBindingNode - Option(localNode) - case someLocalNode => - // When there is already a LOCAL representing the capturing, we do not - // need to process the surrounding scope element as this has already - // been processed. - done = true - someLocalNode - } - case _: BlockScopeElement => None + firstPositionInLine = position + 1 + position += 1 + positionToLineNumber.put(position, lineNumber) + positionToFirstPositionInLine.put(position, firstPositionInLine) + + // for empty line at the end of each JS/TS file generated by BabelJsonParser: + positionToLineNumber.put(position + 1, lineNumber + 1) + positionToFirstPositionInLine.put(position + 1, 0) + (positionToLineNumber, positionToFirstPositionInLine) + end positionLookupTables + + private def computeScopePath(stack: Option[ScopeElement]): String = + new ScopeElementIterator(stack) + .to(Seq) + .reverse + .collect { case methodScopeElement: MethodScopeElement => methodScopeElement.name } + .mkString(":") + + private def isMethodOrGetSet(func: BabelNodeInfo): Boolean = + if hasKey(func.json, "kind") && !func.json("kind").isNull then + val t = func.json("kind").str + t == "method" || t == "get" || t == "set" + else false + + private def calcMethodName(func: BabelNodeInfo): String = func.node match + case TSCallSignatureDeclaration => "anonymous" + case TSConstructSignatureDeclaration => io.appthreat.x2cpg.Defines.ConstructorMethodName + case _ if isMethodOrGetSet(func) => + if hasKey(func.json("key"), "name") then func.json("key")("name").str + else code(func.json("key")) + case _ if safeStr(func.json, "kind").contains("constructor") => + io.appthreat.x2cpg.Defines.ConstructorMethodName + case _ if func.json("id").isNull => "anonymous" + case _ => func.json("id")("name").str + + protected def calcMethodNameAndFullName(func: BabelNodeInfo): (String, String) = + // functionNode.getName is not necessarily unique and thus the full name calculated based on the scope + // is not necessarily unique. Specifically we have this problem with lambda functions which are defined + // in the same scope. + functionNodeToNameAndFullName.get(func) match + case Some(nameAndFullName) => nameAndFullName + case None => + val intendedName = calcMethodName(func) + val fullNamePrefix = + s"${parserResult.filename}:${computeScopePath(scope.getScopeHead)}:" + var name = intendedName + var fullName = "" + var isUnique = false + var i = 1 + while !isUnique do + fullName = s"$fullNamePrefix$name" + if functionFullNames.contains(fullName) then + name = s"$intendedName$i" + i += 1 + else + isUnique = true + functionFullNames.add(fullName) + functionNodeToNameAndFullName(func) = (name, fullName) + (name, fullName) + + protected def stripQuotes(str: String): String = str + .stripPrefix("\"") + .stripSuffix("\"") + .stripPrefix("'") + .stripSuffix("'") + .stripPrefix("`") + .stripSuffix("`") + + /** In JS it is possible to create anonymous classes. We have to handle this here. + */ + private def calcTypeName(classNode: BabelNodeInfo): String = + if hasKey(classNode.json, "id") && !classNode.json("id").isNull then + code(classNode.json("id")) + else "_anon_cdecl" + + protected def calcTypeNameAndFullName( + classNode: BabelNodeInfo, + preCalculatedName: Option[String] = None + ): (String, String) = + val name = preCalculatedName.getOrElse(calcTypeName(classNode)) + val fullNamePrefix = s"${parserResult.filename}:${computeScopePath(scope.getScopeHead)}:" + val intendedFullName = s"$fullNamePrefix$name" + val postfix = typeFullNameToPostfix.getOrElse(intendedFullName, 0) + val resultingFullName = + if postfix == 0 then intendedFullName + else s"$intendedFullName$postfix" + typeFullNameToPostfix.put(intendedFullName, postfix + 1) + (name, resultingFullName) + + protected def createVariableReferenceLinks(): Unit = + val resolvedReferenceIt = scope.resolve(createMethodLocalForUnresolvedReference) + val capturedLocals = mutable.HashMap.empty[String, NewNode] + + resolvedReferenceIt.foreach { case ResolvedReference(variableNodeId, origin) => + var currentScope = origin.stack + var currentReference = origin.referenceNode + var nextReference: NewNode = null + + var done = false + while !done do + val localOrCapturedLocalNodeOption = + if currentScope.get.nameToVariableNode.contains(origin.variableName) then + done = true + Option(variableNodeId) + else + currentScope.flatMap { + case methodScope: MethodScopeElement + if methodScope.scopeNode.isInstanceOf[NewTypeDecl] || methodScope.scopeNode + .isInstanceOf[NewNamespaceBlock] => + currentScope = + Option(Scope.getEnclosingMethodScopeElement(currentScope)) + None + case methodScope: MethodScopeElement => + // We have reached a MethodScope and still did not find a local variable to link to. + // For all non local references the CPG format does not allow us to link + // directly. Instead we need to create a fake local variable in method + // scope and link to this local which itself carries the information + // that it is a captured variable. This needs to be done for each + // method scope until we reach the originating scope. + val closureBindingIdProperty = + s"${methodScope.methodFullName}:${origin.variableName}" + capturedLocals.updateWith(closureBindingIdProperty) { + case None => + val methodScopeNode = methodScope.scopeNode + val localNode = + newLocalNode( + origin.variableName, + Defines.Any, + Option(closureBindingIdProperty) + ).order(0) + diffGraph.addEdge(methodScopeNode, localNode, EdgeTypes.AST) + val closureBindingNode = newClosureBindingNode( + closureBindingIdProperty, + origin.variableName, + EvaluationStrategies.BY_REFERENCE + ) + methodScope.capturingRefId.foreach(ref => + diffGraph.addEdge( + ref, + closureBindingNode, + EdgeTypes.CAPTURE + ) + ) + nextReference = closureBindingNode + Option(localNode) + case someLocalNode => + // When there is already a LOCAL representing the capturing, we do not + // need to process the surrounding scope element as this has already + // been processed. + done = true + someLocalNode } - - localOrCapturedLocalNodeOption.foreach { localOrCapturedLocalNode => - diffGraph.addEdge(currentReference, localOrCapturedLocalNode, EdgeTypes.REF) - currentReference = nextReference + case _: BlockScopeElement => None } - currentScope = currentScope.get.surroundingScope - end while - } - end createVariableReferenceLinks - - private def createMethodLocalForUnresolvedReference( - methodScopeNodeId: NewNode, - variableName: String - ): (NewNode, ScopeType) = - val local = newLocalNode(variableName, Defines.Any).order(0) - diffGraph.addEdge(methodScopeNodeId, local, EdgeTypes.AST) - (local, MethodScope) + + localOrCapturedLocalNodeOption.foreach { localOrCapturedLocalNode => + diffGraph.addEdge(currentReference, localOrCapturedLocalNode, EdgeTypes.REF) + currentReference = nextReference + } + currentScope = currentScope.get.surroundingScope + end while + } + end createVariableReferenceLinks + + private def createMethodLocalForUnresolvedReference( + methodScopeNodeId: NewNode, + variableName: String + ): (NewNode, ScopeType) = + val local = newLocalNode(variableName, Defines.Any).order(0) + diffGraph.addEdge(methodScopeNodeId, local, EdgeTypes.AST) + (local, MethodScope) end AstCreatorHelper diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForDeclarationsCreator.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForDeclarationsCreator.scala index 2a0abf49..8ad8e6a3 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForDeclarationsCreator.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForDeclarationsCreator.scala @@ -14,854 +14,854 @@ import ujson.Value import scala.util.Try trait AstForDeclarationsCreator(implicit withSchemaValidation: ValidationMode): - this: AstCreator => - - private val DefaultsKey = "default" - private val ExportKeyword = "exports" - private val RequireKeyword = "require" - private val ImportKeyword = "import" - - private def hasName(json: Value): Boolean = hasKey(json, "id") && !json("id").isNull - - protected def codeForBabelNodeInfo(obj: BabelNodeInfo): Seq[String] = - val codes = obj.node match - case Identifier => Seq(obj.code) - case NumericLiteral => Seq(obj.code) - case StringLiteral => Seq(obj.code) - case AssignmentExpression => Seq(code(obj.json("left"))) - case ClassDeclaration => Seq(code(obj.json("id"))) - case TSTypeAliasDeclaration => Seq(code(obj.json("id"))) - case TSInterfaceDeclaration => Seq(code(obj.json("id"))) - case TSEnumDeclaration => Seq(code(obj.json("id"))) - case TSModuleDeclaration => Seq(code(obj.json("id"))) - case TSDeclareFunction if hasName(obj.json) => Seq(obj.json("id")("name").str) - case FunctionDeclaration if hasName(obj.json) => Seq(obj.json("id")("name").str) - case FunctionExpression if hasName(obj.json) => Seq(obj.json("id")("name").str) - case ClassExpression if hasName(obj.json) => Seq(obj.json("id")("name").str) - case VariableDeclarator if hasName(obj.json) => - createBabelNodeInfo(obj.json("id")).node match - case ArrayPattern => - obj.json("id")("elements").arr.toSeq.map(createBabelNodeInfo).map(_.code) - case ObjectPattern => - obj.json("id")("properties").arr.toSeq.flatMap(p => - codeForBabelNodeInfo(createBabelNodeInfo(p)) - ) - case _ => - Seq(obj.json("id")("name").str) - case VariableDeclarator => Seq(code(obj.json("id"))) - case MemberExpression => Seq(code(obj.json("property"))) - case ObjectProperty => Seq(code(obj.json("key"))) - case ObjectExpression => - obj.json("properties").arr.toSeq.flatMap(d => - codeForBabelNodeInfo(createBabelNodeInfo(d)) - ) - case VariableDeclaration => - obj.json("declarations").arr.toSeq.flatMap(d => - codeForBabelNodeInfo(createBabelNodeInfo(d)) + this: AstCreator => + + private val DefaultsKey = "default" + private val ExportKeyword = "exports" + private val RequireKeyword = "require" + private val ImportKeyword = "import" + + private def hasName(json: Value): Boolean = hasKey(json, "id") && !json("id").isNull + + protected def codeForBabelNodeInfo(obj: BabelNodeInfo): Seq[String] = + val codes = obj.node match + case Identifier => Seq(obj.code) + case NumericLiteral => Seq(obj.code) + case StringLiteral => Seq(obj.code) + case AssignmentExpression => Seq(code(obj.json("left"))) + case ClassDeclaration => Seq(code(obj.json("id"))) + case TSTypeAliasDeclaration => Seq(code(obj.json("id"))) + case TSInterfaceDeclaration => Seq(code(obj.json("id"))) + case TSEnumDeclaration => Seq(code(obj.json("id"))) + case TSModuleDeclaration => Seq(code(obj.json("id"))) + case TSDeclareFunction if hasName(obj.json) => Seq(obj.json("id")("name").str) + case FunctionDeclaration if hasName(obj.json) => Seq(obj.json("id")("name").str) + case FunctionExpression if hasName(obj.json) => Seq(obj.json("id")("name").str) + case ClassExpression if hasName(obj.json) => Seq(obj.json("id")("name").str) + case VariableDeclarator if hasName(obj.json) => + createBabelNodeInfo(obj.json("id")).node match + case ArrayPattern => + obj.json("id")("elements").arr.toSeq.map(createBabelNodeInfo).map(_.code) + case ObjectPattern => + obj.json("id")("properties").arr.toSeq.flatMap(p => + codeForBabelNodeInfo(createBabelNodeInfo(p)) ) - case _ => Seq.empty - codes.map(_.replace("...", "")) - end codeForBabelNodeInfo - - private def createExportCallAst( - name: String, - exportName: String, - declaration: BabelNodeInfo - ): Ast = - val exportCallAst = if name == DefaultsKey then - createIndexAccessCallAst( - identifierNode(declaration, exportName), - literalNode(declaration, s"\"$DefaultsKey\"", Option(Defines.String)), + case _ => + Seq(obj.json("id")("name").str) + case VariableDeclarator => Seq(code(obj.json("id"))) + case MemberExpression => Seq(code(obj.json("property"))) + case ObjectProperty => Seq(code(obj.json("key"))) + case ObjectExpression => + obj.json("properties").arr.toSeq.flatMap(d => + codeForBabelNodeInfo(createBabelNodeInfo(d)) + ) + case VariableDeclaration => + obj.json("declarations").arr.toSeq.flatMap(d => + codeForBabelNodeInfo(createBabelNodeInfo(d)) + ) + case _ => Seq.empty + codes.map(_.replace("...", "")) + end codeForBabelNodeInfo + + private def createExportCallAst( + name: String, + exportName: String, + declaration: BabelNodeInfo + ): Ast = + val exportCallAst = if name == DefaultsKey then + createIndexAccessCallAst( + identifierNode(declaration, exportName), + literalNode(declaration, s"\"$DefaultsKey\"", Option(Defines.String)), + declaration.lineNumber, + declaration.columnNumber + ) + else + createFieldAccessCallAst( + identifierNode(declaration, exportName), + createFieldIdentifierNode(name, declaration.lineNumber, declaration.columnNumber), + declaration.lineNumber, + declaration.columnNumber + ) + exportCallAst + end createExportCallAst + + private def createExportAssignmentCallAst( + name: String, + exportCallAst: Ast, + declaration: BabelNodeInfo, + from: Option[String] + ): Ast = + from match + case Some(value) => + val call = createFieldAccessCallAst( + identifierNode(declaration, value, Seq.empty), + createFieldIdentifierNode(name, declaration.lineNumber, declaration.columnNumber), declaration.lineNumber, declaration.columnNumber ) - else - createFieldAccessCallAst( - identifierNode(declaration, exportName), - createFieldIdentifierNode(name, declaration.lineNumber, declaration.columnNumber), + createAssignmentCallAst( + exportCallAst, + call, + s"${codeOf(exportCallAst.nodes.head)} = ${codeOf(call.nodes.head)}", declaration.lineNumber, declaration.columnNumber ) - exportCallAst - end createExportCallAst - - private def createExportAssignmentCallAst( - name: String, - exportCallAst: Ast, - declaration: BabelNodeInfo, - from: Option[String] - ): Ast = - from match - case Some(value) => - val call = createFieldAccessCallAst( - identifierNode(declaration, value, Seq.empty), - createFieldIdentifierNode(name, declaration.lineNumber, declaration.columnNumber), - declaration.lineNumber, - declaration.columnNumber - ) - createAssignmentCallAst( - exportCallAst, - call, - s"${codeOf(exportCallAst.nodes.head)} = ${codeOf(call.nodes.head)}", - declaration.lineNumber, - declaration.columnNumber - ) - case None => - createAssignmentCallAst( - exportCallAst, - Ast(identifierNode(declaration, name)), - s"${codeOf(exportCallAst.nodes.head)} = $name", - declaration.lineNumber, - declaration.columnNumber - ) - - private def extractDeclarationsFromExportDecl( - declaration: BabelNodeInfo, - key: String - ): Option[(Ast, Seq[String])] = - safeObj(declaration.json, key) - .map { d => - val nodeInfo = createBabelNodeInfo(d) - val ast = astForNodeWithFunctionReferenceAndCall(d) - val defaultName = codeForNodes(ast.nodes.toSeq) - val codes = codeForBabelNodeInfo(nodeInfo) - val names = if codes.isEmpty then defaultName.toSeq else codes - (ast, names) - } - - private def extractExportFromNameFromExportDecl(declaration: BabelNodeInfo): String = - safeObj(declaration.json, "source") - .map(d => s"_${stripQuotes(code(d))}") - .getOrElse(ExportKeyword) - - private def cleanImportName(name: String): String = if name.contains("/") then - val stripped = name.stripSuffix("/") - stripped.substring(stripped.lastIndexOf("/") + 1) - else name - - private def createAstForFrom(fromName: String, declaration: BabelNodeInfo): Ast = - if fromName == ExportKeyword then - Ast() - else - val strippedCode = cleanImportName(fromName).stripPrefix("_") - val id = identifierNode(declaration, s"_$strippedCode") - val localNode = newLocalNode(id.code, Defines.Any).order(0) - scope.addVariable(id.code, localNode, BlockScope) - scope.addVariableReference(id.code, id) - diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) - - val sourceCallArgNode = - literalNode(declaration, s"\"${fromName.stripPrefix("_")}\"", None) - val sourceCall = - callNode( - declaration, - s"$RequireKeyword(${sourceCallArgNode.code})", - RequireKeyword, - DispatchTypes.STATIC_DISPATCH - ) - val sourceAst = - callAst(sourceCall, List(Ast(sourceCallArgNode))) - val assignmentCallAst = createAssignmentCallAst( - Ast(id), - sourceAst, - s"var ${codeOf(id)} = ${codeOf(sourceAst.nodes.head)}", + case None => + createAssignmentCallAst( + exportCallAst, + Ast(identifierNode(declaration, name)), + s"${codeOf(exportCallAst.nodes.head)} = $name", declaration.lineNumber, declaration.columnNumber ) - assignmentCallAst - - protected def astsForDecorators(elem: BabelNodeInfo): Seq[Ast] = - if hasKey(elem.json, "decorators") && !elem.json("decorators").isNull then - elem.json("decorators").arr.toList.map(d => astForDecorator(createBabelNodeInfo(d))) - else Seq.empty - - private def namesForDecoratorExpression(code: String): (String, String) = - val dotLastIndex = code.lastIndexOf(".") - if dotLastIndex != -1 then - (code.substring(dotLastIndex + 1), code) - else - (code, code) - - private def astForDecorator(decorator: BabelNodeInfo): Ast = - val exprNode = createBabelNodeInfo(decorator.json("expression")) - exprNode.node match - case Identifier | MemberExpression => - val (name, fullName) = namesForDecoratorExpression(code(exprNode.json)) - annotationAst(annotationNode(decorator, decorator.code, name, fullName), List.empty) - case CallExpression => - val (name, fullName) = namesForDecoratorExpression(code(exprNode.json("callee"))) - val node = annotationNode(decorator, decorator.code, name, fullName) - val assignmentAsts = exprNode.json("arguments").arr.toList.map { arg => - createBabelNodeInfo(arg).node match - case AssignmentExpression => - annotationAssignmentAst( - code(arg("left")), - code(arg), - astForNodeWithFunctionReference(arg("right")) - ) - case _ => - annotationAssignmentAst( - "value", - code(arg), - astForNodeWithFunctionReference(arg) - ) - } - annotationAst(node, assignmentAsts) - case _ => Ast() - end match - end astForDecorator - - protected def astForExportNamedDeclaration(declaration: BabelNodeInfo): Ast = - val specifiers = declaration - .json("specifiers") - .arr - .toList - .map { spec => - if createBabelNodeInfo(spec).node == ExportNamespaceSpecifier then - val exported = createBabelNodeInfo(spec("exported")) - (None, Option(exported)) - else - val exported = createBabelNodeInfo(spec("exported")) - val local = if hasKey(spec, "local") then - createBabelNodeInfo(spec("local")) - else - exported - (Option(local), Option(exported)) - } - - val exportName = extractExportFromNameFromExportDecl(declaration) - val fromAst = createAstForFrom(exportName, declaration) - val declAstAndNames = extractDeclarationsFromExportDecl(declaration, "declaration") - val declAsts = declAstAndNames.toList.flatMap { case (ast, names) => - ast +: names.map { name => - if exportName != ExportKeyword then - diffGraph.addNode(newDependencyNode( - name, - exportName.stripPrefix("_"), - RequireKeyword - )) - val exportCallAst = createExportCallAst(name, exportName, declaration) - createExportAssignmentCallAst(name, exportCallAst, declaration, None) - } - } - - val specifierAsts = specifiers.map { - case (Some(name), Some(alias)) => - val strippedCode = cleanImportName(exportName).stripPrefix("_") - val exportCallAst = createExportCallAst(alias.code, ExportKeyword, declaration) - if exportName != ExportKeyword then - diffGraph.addNode(newDependencyNode( - alias.code, - exportName.stripPrefix("_"), - RequireKeyword - )) - createExportAssignmentCallAst( - name.code, - exportCallAst, - declaration, - Option(s"_$strippedCode") - ) - else - createExportAssignmentCallAst(name.code, exportCallAst, declaration, None) - case (None, Some(alias)) => - diffGraph.addNode(newDependencyNode( - alias.code, - exportName.stripPrefix("_"), - RequireKeyword - )) - val exportCallAst = createExportCallAst(alias.code, ExportKeyword, declaration) - createExportAssignmentCallAst(exportName, exportCallAst, declaration, None) - case _ => Ast() - } - - val asts = fromAst +: (specifierAsts ++ declAsts) - setArgumentIndices(asts) - blockAst(createBlockNode(declaration), asts) - end astForExportNamedDeclaration - - protected def astForExportAssignment(assignment: BabelNodeInfo): Ast = - val expressionAstWithNames = extractDeclarationsFromExportDecl(assignment, "expression") - val declAsts = expressionAstWithNames.toList.flatMap { case (ast, names) => - ast +: names.map { name => - val exportCallAst = createExportCallAst(name, ExportKeyword, assignment) - createExportAssignmentCallAst(name, exportCallAst, assignment, None) - } - } - - setArgumentIndices(declAsts) - blockAst(createBlockNode(assignment), declAsts) - - protected def astForExportDefaultDeclaration(declaration: BabelNodeInfo): Ast = - val exportName = extractExportFromNameFromExportDecl(declaration) - val declAstAndNames = extractDeclarationsFromExportDecl(declaration, "declaration") - val declAsts = declAstAndNames.toList.flatMap { case (ast, names) => - ast +: names.map { name => - val exportCallAst = createExportCallAst(DefaultsKey, exportName, declaration) - createExportAssignmentCallAst(name, exportCallAst, declaration, None) - } - } - setArgumentIndices(declAsts) - blockAst(createBlockNode(declaration), declAsts) - - protected def astForExportAllDeclaration(declaration: BabelNodeInfo): Ast = - val exportName = extractExportFromNameFromExportDecl(declaration) - val depGroupId = stripQuotes(code(declaration.json("source"))) - val name = cleanImportName(depGroupId) - if exportName != ExportKeyword then - diffGraph.addNode(newDependencyNode(name, depGroupId, RequireKeyword)) - - val fromCallAst = createAstForFrom(exportName, declaration) - val exportCallAst = createExportCallAst(name, ExportKeyword, declaration) - val assignmentCallAst = - createExportAssignmentCallAst(s"_$name", exportCallAst, declaration, None) - - val childrenAsts = List(fromCallAst, assignmentCallAst) - setArgumentIndices(childrenAsts) - blockAst(createBlockNode(declaration), childrenAsts) - - protected def astForVariableDeclaration(declaration: BabelNodeInfo): Ast = - val kind = declaration.json("kind").str - val scopeType = if kind == "let" then - BlockScope - else - MethodScope - val declAsts = declaration.json("declarations").arr.toList.map(astForVariableDeclarator( - _, - scopeType, - kind - )) - declAsts match - case Nil => Ast() - case head :: Nil => head - case _ => blockAst(createBlockNode(declaration), declAsts) - - private def handleRequireCallForDependencies( - declarator: BabelNodeInfo, - lhs: Value, - rhs: Value, - call: Option[NewCall] - ): Unit = - val rhsCode = code(rhs) - val groupId = - rhsCode.substring(rhsCode.indexOf(s"$RequireKeyword(") + 9, rhsCode.indexOf(")") - 1) - val nodeInfo = createBabelNodeInfo(lhs) - val names = nodeInfo.node match - case ArrayPattern => nodeInfo.json("elements").arr.toList.map(code) - case ObjectPattern => nodeInfo.json("properties").arr.toList.map(code) - case _ => List(code(lhs)) - names.foreach { name => - val _dependencyNode = newDependencyNode(name, groupId, RequireKeyword) - diffGraph.addNode(_dependencyNode) - val importNode = createImportNodeAndAttachToCall(declarator, groupId, name, call) - diffGraph.addEdge(importNode, _dependencyNode, EdgeTypes.IMPORTS) - } - end handleRequireCallForDependencies - - private def astForVariableDeclarator( - declarator: Value, - scopeType: ScopeType, - kind: String - ): Ast = - val idNodeInfo = createBabelNodeInfo(declarator("id")) - val declNodeInfo = createBabelNodeInfo(declarator) - val initNodeInfo = Try(createBabelNodeInfo(declarator("init"))).toOption - val declaratorCode = s"$kind ${code(declarator)}" - val typeFullName = typeFor(declNodeInfo) - - val idName = idNodeInfo.node match - case Identifier => idNodeInfo.json("name").str - case _ => idNodeInfo.code - val localNode = newLocalNode(idName, typeFullName).order(0) - scope.addVariable(idName, localNode, scopeType) - diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) - - if initNodeInfo.isEmpty then - Ast() - else - val sourceAst = initNodeInfo.get match - case requireCall if requireCall.code.startsWith(s"$RequireKeyword(") => - val call = astForNodeWithFunctionReference(requireCall.json) - handleRequireCallForDependencies( - createBabelNodeInfo(declarator), - idNodeInfo.json, - initNodeInfo.get.json, - call.root.map(_.asInstanceOf[NewCall]) - ) - call - case initExpr => - astForNodeWithFunctionReference(initExpr.json) - val nodeInfo = createBabelNodeInfo(idNodeInfo.json) - nodeInfo.node match - case ObjectPattern | ArrayPattern => - astForDeconstruction(nodeInfo, sourceAst, declaratorCode) - case _ => - val destAst = idNodeInfo.node match - case Identifier => astForIdentifier(idNodeInfo, Option(typeFullName)) - case _ => astForNode(idNodeInfo.json) - - val assignmentCallAst = - createAssignmentCallAst( - destAst, - sourceAst, - declaratorCode, - line = line(declarator), - column = column(declarator) - ) - assignmentCallAst - end if - end astForVariableDeclarator - - protected def astForTSImportEqualsDeclaration(impDecl: BabelNodeInfo): Ast = - val name = impDecl.json("id")("name").str - val referenceNode = createBabelNodeInfo(impDecl.json("moduleReference")) - val referenceName = referenceNode.node match - case TSExternalModuleReference => referenceNode.json("expression")("value").str - case _ => referenceNode.code - val _dependencyNode = newDependencyNode(name, referenceName, ImportKeyword) - diffGraph.addNode(_dependencyNode) - val assignment = - astForRequireCallFromImport(name, None, referenceName, isImportN = false, impDecl) - val call = assignment.nodes.collectFirst { case x: NewCall if x.name == "require" => x } - val importNode = - createImportNodeAndAttachToCall(impDecl, referenceName, name, call) - diffGraph.addEdge(importNode, _dependencyNode, EdgeTypes.IMPORTS) - assignment - private def astForRequireCallFromImport( - name: String, - alias: Option[String], - from: String, - isImportN: Boolean, - nodeInfo: BabelNodeInfo - ): Ast = - val destName = alias.getOrElse(name) - val destNode = identifierNode(nodeInfo, destName) - val localNode = newLocalNode(destName, Defines.Any).order(0) - scope.addVariable(destName, localNode, BlockScope) - scope.addVariableReference(destName, destNode) + private def extractDeclarationsFromExportDecl( + declaration: BabelNodeInfo, + key: String + ): Option[(Ast, Seq[String])] = + safeObj(declaration.json, key) + .map { d => + val nodeInfo = createBabelNodeInfo(d) + val ast = astForNodeWithFunctionReferenceAndCall(d) + val defaultName = codeForNodes(ast.nodes.toSeq) + val codes = codeForBabelNodeInfo(nodeInfo) + val names = if codes.isEmpty then defaultName.toSeq else codes + (ast, names) + } + + private def extractExportFromNameFromExportDecl(declaration: BabelNodeInfo): String = + safeObj(declaration.json, "source") + .map(d => s"_${stripQuotes(code(d))}") + .getOrElse(ExportKeyword) + + private def cleanImportName(name: String): String = if name.contains("/") then + val stripped = name.stripSuffix("/") + stripped.substring(stripped.lastIndexOf("/") + 1) + else name + + private def createAstForFrom(fromName: String, declaration: BabelNodeInfo): Ast = + if fromName == ExportKeyword then + Ast() + else + val strippedCode = cleanImportName(fromName).stripPrefix("_") + val id = identifierNode(declaration, s"_$strippedCode") + val localNode = newLocalNode(id.code, Defines.Any).order(0) + scope.addVariable(id.code, localNode, BlockScope) + scope.addVariableReference(id.code, id) diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) - val destAst = Ast(destNode) - val sourceCallArgNode = literalNode(nodeInfo, s"\"$from\"", None) + val sourceCallArgNode = + literalNode(declaration, s"\"${fromName.stripPrefix("_")}\"", None) val sourceCall = callNode( - nodeInfo, + declaration, s"$RequireKeyword(${sourceCallArgNode.code})", RequireKeyword, - DispatchTypes.DYNAMIC_DISPATCH + DispatchTypes.STATIC_DISPATCH ) - - val receiverNode = identifierNode(nodeInfo, RequireKeyword) - val thisNode = - identifierNode(nodeInfo, "this").dynamicTypeHintFullName(typeHintForThisExpression()) - scope.addVariableReference(thisNode.name, thisNode) - val cAst = callAst( - sourceCall, - List(Ast(sourceCallArgNode)), - receiver = Option(Ast(receiverNode)), - base = Option(Ast(thisNode)) + val sourceAst = + callAst(sourceCall, List(Ast(sourceCallArgNode))) + val assignmentCallAst = createAssignmentCallAst( + Ast(id), + sourceAst, + s"var ${codeOf(id)} = ${codeOf(sourceAst.nodes.head)}", + declaration.lineNumber, + declaration.columnNumber ) - val sourceAst = if isImportN then - val fieldAccessCall = createFieldAccessCallAst( - cAst, - createFieldIdentifierNode(name, nodeInfo.lineNumber, nodeInfo.columnNumber), - nodeInfo.lineNumber, - nodeInfo.columnNumber - ) - fieldAccessCall - else - cAst - val assigmentCallAst = - createAssignmentCallAst( - destAst, - sourceAst, - s"var ${codeOf(destAst.nodes.head)} = ${codeOf(sourceAst.nodes.head)}", - nodeInfo.lineNumber, - nodeInfo.columnNumber + assignmentCallAst + + protected def astsForDecorators(elem: BabelNodeInfo): Seq[Ast] = + if hasKey(elem.json, "decorators") && !elem.json("decorators").isNull then + elem.json("decorators").arr.toList.map(d => astForDecorator(createBabelNodeInfo(d))) + else Seq.empty + + private def namesForDecoratorExpression(code: String): (String, String) = + val dotLastIndex = code.lastIndexOf(".") + if dotLastIndex != -1 then + (code.substring(dotLastIndex + 1), code) + else + (code, code) + + private def astForDecorator(decorator: BabelNodeInfo): Ast = + val exprNode = createBabelNodeInfo(decorator.json("expression")) + exprNode.node match + case Identifier | MemberExpression => + val (name, fullName) = namesForDecoratorExpression(code(exprNode.json)) + annotationAst(annotationNode(decorator, decorator.code, name, fullName), List.empty) + case CallExpression => + val (name, fullName) = namesForDecoratorExpression(code(exprNode.json("callee"))) + val node = annotationNode(decorator, decorator.code, name, fullName) + val assignmentAsts = exprNode.json("arguments").arr.toList.map { arg => + createBabelNodeInfo(arg).node match + case AssignmentExpression => + annotationAssignmentAst( + code(arg("left")), + code(arg), + astForNodeWithFunctionReference(arg("right")) + ) + case _ => + annotationAssignmentAst( + "value", + code(arg), + astForNodeWithFunctionReference(arg) + ) + } + annotationAst(node, assignmentAsts) + case _ => Ast() + end match + end astForDecorator + + protected def astForExportNamedDeclaration(declaration: BabelNodeInfo): Ast = + val specifiers = declaration + .json("specifiers") + .arr + .toList + .map { spec => + if createBabelNodeInfo(spec).node == ExportNamespaceSpecifier then + val exported = createBabelNodeInfo(spec("exported")) + (None, Option(exported)) + else + val exported = createBabelNodeInfo(spec("exported")) + val local = if hasKey(spec, "local") then + createBabelNodeInfo(spec("local")) + else + exported + (Option(local), Option(exported)) + } + + val exportName = extractExportFromNameFromExportDecl(declaration) + val fromAst = createAstForFrom(exportName, declaration) + val declAstAndNames = extractDeclarationsFromExportDecl(declaration, "declaration") + val declAsts = declAstAndNames.toList.flatMap { case (ast, names) => + ast +: names.map { name => + if exportName != ExportKeyword then + diffGraph.addNode(newDependencyNode( + name, + exportName.stripPrefix("_"), + RequireKeyword + )) + val exportCallAst = createExportCallAst(name, exportName, declaration) + createExportAssignmentCallAst(name, exportCallAst, declaration, None) + } + } + + val specifierAsts = specifiers.map { + case (Some(name), Some(alias)) => + val strippedCode = cleanImportName(exportName).stripPrefix("_") + val exportCallAst = createExportCallAst(alias.code, ExportKeyword, declaration) + if exportName != ExportKeyword then + diffGraph.addNode(newDependencyNode( + alias.code, + exportName.stripPrefix("_"), + RequireKeyword + )) + createExportAssignmentCallAst( + name.code, + exportCallAst, + declaration, + Option(s"_$strippedCode") + ) + else + createExportAssignmentCallAst(name.code, exportCallAst, declaration, None) + case (None, Some(alias)) => + diffGraph.addNode(newDependencyNode( + alias.code, + exportName.stripPrefix("_"), + RequireKeyword + )) + val exportCallAst = createExportCallAst(alias.code, ExportKeyword, declaration) + createExportAssignmentCallAst(exportName, exportCallAst, declaration, None) + case _ => Ast() + } + + val asts = fromAst +: (specifierAsts ++ declAsts) + setArgumentIndices(asts) + blockAst(createBlockNode(declaration), asts) + end astForExportNamedDeclaration + + protected def astForExportAssignment(assignment: BabelNodeInfo): Ast = + val expressionAstWithNames = extractDeclarationsFromExportDecl(assignment, "expression") + val declAsts = expressionAstWithNames.toList.flatMap { case (ast, names) => + ast +: names.map { name => + val exportCallAst = createExportCallAst(name, ExportKeyword, assignment) + createExportAssignmentCallAst(name, exportCallAst, assignment, None) + } + } + + setArgumentIndices(declAsts) + blockAst(createBlockNode(assignment), declAsts) + + protected def astForExportDefaultDeclaration(declaration: BabelNodeInfo): Ast = + val exportName = extractExportFromNameFromExportDecl(declaration) + val declAstAndNames = extractDeclarationsFromExportDecl(declaration, "declaration") + val declAsts = declAstAndNames.toList.flatMap { case (ast, names) => + ast +: names.map { name => + val exportCallAst = createExportCallAst(DefaultsKey, exportName, declaration) + createExportAssignmentCallAst(name, exportCallAst, declaration, None) + } + } + setArgumentIndices(declAsts) + blockAst(createBlockNode(declaration), declAsts) + + protected def astForExportAllDeclaration(declaration: BabelNodeInfo): Ast = + val exportName = extractExportFromNameFromExportDecl(declaration) + val depGroupId = stripQuotes(code(declaration.json("source"))) + val name = cleanImportName(depGroupId) + if exportName != ExportKeyword then + diffGraph.addNode(newDependencyNode(name, depGroupId, RequireKeyword)) + + val fromCallAst = createAstForFrom(exportName, declaration) + val exportCallAst = createExportCallAst(name, ExportKeyword, declaration) + val assignmentCallAst = + createExportAssignmentCallAst(s"_$name", exportCallAst, declaration, None) + + val childrenAsts = List(fromCallAst, assignmentCallAst) + setArgumentIndices(childrenAsts) + blockAst(createBlockNode(declaration), childrenAsts) + + protected def astForVariableDeclaration(declaration: BabelNodeInfo): Ast = + val kind = declaration.json("kind").str + val scopeType = if kind == "let" then + BlockScope + else + MethodScope + val declAsts = declaration.json("declarations").arr.toList.map(astForVariableDeclarator( + _, + scopeType, + kind + )) + declAsts match + case Nil => Ast() + case head :: Nil => head + case _ => blockAst(createBlockNode(declaration), declAsts) + + private def handleRequireCallForDependencies( + declarator: BabelNodeInfo, + lhs: Value, + rhs: Value, + call: Option[NewCall] + ): Unit = + val rhsCode = code(rhs) + val groupId = + rhsCode.substring(rhsCode.indexOf(s"$RequireKeyword(") + 9, rhsCode.indexOf(")") - 1) + val nodeInfo = createBabelNodeInfo(lhs) + val names = nodeInfo.node match + case ArrayPattern => nodeInfo.json("elements").arr.toList.map(code) + case ObjectPattern => nodeInfo.json("properties").arr.toList.map(code) + case _ => List(code(lhs)) + names.foreach { name => + val _dependencyNode = newDependencyNode(name, groupId, RequireKeyword) + diffGraph.addNode(_dependencyNode) + val importNode = createImportNodeAndAttachToCall(declarator, groupId, name, call) + diffGraph.addEdge(importNode, _dependencyNode, EdgeTypes.IMPORTS) + } + end handleRequireCallForDependencies + + private def astForVariableDeclarator( + declarator: Value, + scopeType: ScopeType, + kind: String + ): Ast = + val idNodeInfo = createBabelNodeInfo(declarator("id")) + val declNodeInfo = createBabelNodeInfo(declarator) + val initNodeInfo = Try(createBabelNodeInfo(declarator("init"))).toOption + val declaratorCode = s"$kind ${code(declarator)}" + val typeFullName = typeFor(declNodeInfo) + + val idName = idNodeInfo.node match + case Identifier => idNodeInfo.json("name").str + case _ => idNodeInfo.code + val localNode = newLocalNode(idName, typeFullName).order(0) + scope.addVariable(idName, localNode, scopeType) + diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) + + if initNodeInfo.isEmpty then + Ast() + else + val sourceAst = initNodeInfo.get match + case requireCall if requireCall.code.startsWith(s"$RequireKeyword(") => + val call = astForNodeWithFunctionReference(requireCall.json) + handleRequireCallForDependencies( + createBabelNodeInfo(declarator), + idNodeInfo.json, + initNodeInfo.get.json, + call.root.map(_.asInstanceOf[NewCall]) ) - assigmentCallAst - end astForRequireCallFromImport - - protected def astForImportDeclaration(impDecl: BabelNodeInfo): Ast = - val source = impDecl.json("source")("value").str - val specifiers = impDecl.json("specifiers").arr - - if specifiers.isEmpty then - val _dependencyNode = newDependencyNode(source, source, ImportKeyword) - diffGraph.addNode(_dependencyNode) - val assignment = - astForRequireCallFromImport(source, None, source, isImportN = false, impDecl) - val call = assignment.nodes.collectFirst { case x: NewCall if x.name == "require" => x } - val importNode = createImportNodeAndAttachToCall(impDecl, source, source, call) - diffGraph.addEdge(importNode, _dependencyNode, EdgeTypes.IMPORTS) - assignment - else - val specs = impDecl.json("specifiers").arr.toList - val requireCalls = specs.map { importSpecifier => - val isImportN = createBabelNodeInfo(importSpecifier).node match - case ImportSpecifier => true - case _ => false - val name = importSpecifier("local")("name").str - val (alias, reqName) = reqNameFromImportSpecifier(importSpecifier, name) - val assignment = astForRequireCallFromImport( - reqName, - alias, - source, - isImportN = isImportN, - impDecl - ) - val importedName = importSpecifier("local")("name").str - val call = - assignment.nodes.collectFirst { case x: NewCall if x.name == "require" => x } - val importNode = createImportNodeAndAttachToCall( - impDecl, - s"$source:$reqName", - importedName, - call + call + case initExpr => + astForNodeWithFunctionReference(initExpr.json) + val nodeInfo = createBabelNodeInfo(idNodeInfo.json) + nodeInfo.node match + case ObjectPattern | ArrayPattern => + astForDeconstruction(nodeInfo, sourceAst, declaratorCode) + case _ => + val destAst = idNodeInfo.node match + case Identifier => astForIdentifier(idNodeInfo, Option(typeFullName)) + case _ => astForNode(idNodeInfo.json) + + val assignmentCallAst = + createAssignmentCallAst( + destAst, + sourceAst, + declaratorCode, + line = line(declarator), + column = column(declarator) ) - val _dependencyNode = newDependencyNode(importedName, source, ImportKeyword) - diffGraph.addEdge(importNode, _dependencyNode, EdgeTypes.IMPORTS) - diffGraph.addNode(_dependencyNode) - assignment - } - if requireCalls.isEmpty then - Ast() - else if requireCalls.sizeIs == 1 then - requireCalls.head - else - blockAst(createBlockNode(impDecl), requireCalls) - end if - end astForImportDeclaration - - private def reqNameFromImportSpecifier(importSpecifier: Value, name: String) = - if hasKey(importSpecifier, "imported") then - (Option(name), importSpecifier("imported")("name").str) - else - (None, name) - - private def createImportNodeAndAttachToCall( - impDecl: BabelNodeInfo, - importedEntity: String, - importedAs: String, - call: Option[NewCall] - ): NewImport = - createImportNodeAndAttachToCall( - impDecl.code.stripSuffix(";"), - importedEntity, - importedAs, - call + assignmentCallAst + end if + end astForVariableDeclarator + + protected def astForTSImportEqualsDeclaration(impDecl: BabelNodeInfo): Ast = + val name = impDecl.json("id")("name").str + val referenceNode = createBabelNodeInfo(impDecl.json("moduleReference")) + val referenceName = referenceNode.node match + case TSExternalModuleReference => referenceNode.json("expression")("value").str + case _ => referenceNode.code + val _dependencyNode = newDependencyNode(name, referenceName, ImportKeyword) + diffGraph.addNode(_dependencyNode) + val assignment = + astForRequireCallFromImport(name, None, referenceName, isImportN = false, impDecl) + val call = assignment.nodes.collectFirst { case x: NewCall if x.name == "require" => x } + val importNode = + createImportNodeAndAttachToCall(impDecl, referenceName, name, call) + diffGraph.addEdge(importNode, _dependencyNode, EdgeTypes.IMPORTS) + assignment + + private def astForRequireCallFromImport( + name: String, + alias: Option[String], + from: String, + isImportN: Boolean, + nodeInfo: BabelNodeInfo + ): Ast = + val destName = alias.getOrElse(name) + val destNode = identifierNode(nodeInfo, destName) + val localNode = newLocalNode(destName, Defines.Any).order(0) + scope.addVariable(destName, localNode, BlockScope) + scope.addVariableReference(destName, destNode) + diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) + + val destAst = Ast(destNode) + val sourceCallArgNode = literalNode(nodeInfo, s"\"$from\"", None) + val sourceCall = + callNode( + nodeInfo, + s"$RequireKeyword(${sourceCallArgNode.code})", + RequireKeyword, + DispatchTypes.DYNAMIC_DISPATCH ) - private def createImportNodeAndAttachToCall( - code: String, - importedEntity: String, - importedAs: String, - call: Option[NewCall] - ): NewImport = - val impNode = NewImport() - .code(code) - .importedEntity(importedEntity) - .importedAs(importedAs) - .lineNumber(call.flatMap(_.lineNumber)) - .columnNumber(call.flatMap(_.lineNumber)) - call.foreach { c => diffGraph.addEdge(c, impNode, EdgeTypes.IS_CALL_FOR_IMPORT) } - impNode - - private def convertDestructingObjectElement( - element: BabelNodeInfo, - key: BabelNodeInfo, - localTmpName: String - ): Ast = - val valueAst = astForNode(element.json) - - val localNode = newLocalNode(element.code, Defines.Any).order(0) - diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) - scope.addVariable(element.code, localNode, MethodScope) - - val fieldAccessTmpNode = identifierNode(element, localTmpName) - val keyNode = createFieldIdentifierNode(key.code, key.lineNumber, key.columnNumber) - val accessAst = createFieldAccessCallAst( - fieldAccessTmpNode, - keyNode, - element.lineNumber, - element.columnNumber - ) + val receiverNode = identifierNode(nodeInfo, RequireKeyword) + val thisNode = + identifierNode(nodeInfo, "this").dynamicTypeHintFullName(typeHintForThisExpression()) + scope.addVariableReference(thisNode.name, thisNode) + val cAst = callAst( + sourceCall, + List(Ast(sourceCallArgNode)), + receiver = Option(Ast(receiverNode)), + base = Option(Ast(thisNode)) + ) + val sourceAst = if isImportN then + val fieldAccessCall = createFieldAccessCallAst( + cAst, + createFieldIdentifierNode(name, nodeInfo.lineNumber, nodeInfo.columnNumber), + nodeInfo.lineNumber, + nodeInfo.columnNumber + ) + fieldAccessCall + else + cAst + val assigmentCallAst = createAssignmentCallAst( - valueAst, - accessAst, - s"${codeOf(valueAst.nodes.head)} = ${codeOf(accessAst.nodes.head)}", - element.lineNumber, - element.columnNumber + destAst, + sourceAst, + s"var ${codeOf(destAst.nodes.head)} = ${codeOf(sourceAst.nodes.head)}", + nodeInfo.lineNumber, + nodeInfo.columnNumber ) - end convertDestructingObjectElement - - private def convertDestructingArrayElement( - element: BabelNodeInfo, - index: Int, - localTmpName: String - ): Ast = - val valueAst = astForNode(element.json) - - val localNode = newLocalNode(element.code, Defines.Any).order(0) - diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) - scope.addVariable(element.code, localNode, MethodScope) - - val fieldAccessTmpNode = identifierNode(element, localTmpName) - val keyNode = literalNode(element, index.toString, Option(Defines.Number)) - val accessAst = createIndexAccessCallAst( - fieldAccessTmpNode, - keyNode, - element.lineNumber, - element.columnNumber + assigmentCallAst + end astForRequireCallFromImport + + protected def astForImportDeclaration(impDecl: BabelNodeInfo): Ast = + val source = impDecl.json("source")("value").str + val specifiers = impDecl.json("specifiers").arr + + if specifiers.isEmpty then + val _dependencyNode = newDependencyNode(source, source, ImportKeyword) + diffGraph.addNode(_dependencyNode) + val assignment = + astForRequireCallFromImport(source, None, source, isImportN = false, impDecl) + val call = assignment.nodes.collectFirst { case x: NewCall if x.name == "require" => x } + val importNode = createImportNodeAndAttachToCall(impDecl, source, source, call) + diffGraph.addEdge(importNode, _dependencyNode, EdgeTypes.IMPORTS) + assignment + else + val specs = impDecl.json("specifiers").arr.toList + val requireCalls = specs.map { importSpecifier => + val isImportN = createBabelNodeInfo(importSpecifier).node match + case ImportSpecifier => true + case _ => false + val name = importSpecifier("local")("name").str + val (alias, reqName) = reqNameFromImportSpecifier(importSpecifier, name) + val assignment = astForRequireCallFromImport( + reqName, + alias, + source, + isImportN = isImportN, + impDecl ) - createAssignmentCallAst( - valueAst, - accessAst, - s"${codeOf(valueAst.nodes.head)} = ${codeOf(accessAst.nodes.head)}", - element.lineNumber, - element.columnNumber + val importedName = importSpecifier("local")("name").str + val call = + assignment.nodes.collectFirst { case x: NewCall if x.name == "require" => x } + val importNode = createImportNodeAndAttachToCall( + impDecl, + s"$source:$reqName", + importedName, + call ) - end convertDestructingArrayElement - - private def convertDestructingArrayElementWithDefault( - element: BabelNodeInfo, - index: Int, - localTmpName: String - ): Ast = - val rhsElement = element.json("right") - val rhsAst = astForNodeWithFunctionReference(rhsElement) - - val lhsElement = element.json("left") - val nodeInfo = createBabelNodeInfo(lhsElement) - val lhsAst = nodeInfo.node match - case ObjectPattern | ArrayPattern => - val sourceAst = - astForNodeWithFunctionReference(createBabelNodeInfo(rhsElement).json) - astForDeconstruction(nodeInfo, sourceAst, element.code) - case _ => astForNodeWithFunctionReference(lhsElement) - - val testAst = - val fieldAccessTmpNode = identifierNode(element, localTmpName) - val keyNode = literalNode(element, index.toString, Option(Defines.Number)) - val accessAst = - createIndexAccessCallAst( - fieldAccessTmpNode, - keyNode, - element.lineNumber, - element.columnNumber - ) - val voidCallNode = createVoidCallNode(element.lineNumber, element.columnNumber) - createEqualsCallAst( - accessAst, - Ast(voidCallNode), - element.lineNumber, - element.columnNumber - ) - val falseAst = - val fieldAccessTmpNode = identifierNode(element, localTmpName) - val keyNode = literalNode(element, index.toString, Option(Defines.Number)) - createIndexAccessCallAst( - fieldAccessTmpNode, - keyNode, - element.lineNumber, - element.columnNumber - ) - val ternaryNodeAst = - createTernaryCallAst( - testAst, - rhsAst, - falseAst, - element.lineNumber, - element.columnNumber - ) - createAssignmentCallAst( - lhsAst, - ternaryNodeAst, - s"${codeOf(lhsAst.nodes.head)} = ${codeOf(ternaryNodeAst.nodes.head)}", + val _dependencyNode = newDependencyNode(importedName, source, ImportKeyword) + diffGraph.addEdge(importNode, _dependencyNode, EdgeTypes.IMPORTS) + diffGraph.addNode(_dependencyNode) + assignment + } + if requireCalls.isEmpty then + Ast() + else if requireCalls.sizeIs == 1 then + requireCalls.head + else + blockAst(createBlockNode(impDecl), requireCalls) + end if + end astForImportDeclaration + + private def reqNameFromImportSpecifier(importSpecifier: Value, name: String) = + if hasKey(importSpecifier, "imported") then + (Option(name), importSpecifier("imported")("name").str) + else + (None, name) + + private def createImportNodeAndAttachToCall( + impDecl: BabelNodeInfo, + importedEntity: String, + importedAs: String, + call: Option[NewCall] + ): NewImport = + createImportNodeAndAttachToCall( + impDecl.code.stripSuffix(";"), + importedEntity, + importedAs, + call + ) + + private def createImportNodeAndAttachToCall( + code: String, + importedEntity: String, + importedAs: String, + call: Option[NewCall] + ): NewImport = + val impNode = NewImport() + .code(code) + .importedEntity(importedEntity) + .importedAs(importedAs) + .lineNumber(call.flatMap(_.lineNumber)) + .columnNumber(call.flatMap(_.lineNumber)) + call.foreach { c => diffGraph.addEdge(c, impNode, EdgeTypes.IS_CALL_FOR_IMPORT) } + impNode + + private def convertDestructingObjectElement( + element: BabelNodeInfo, + key: BabelNodeInfo, + localTmpName: String + ): Ast = + val valueAst = astForNode(element.json) + + val localNode = newLocalNode(element.code, Defines.Any).order(0) + diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) + scope.addVariable(element.code, localNode, MethodScope) + + val fieldAccessTmpNode = identifierNode(element, localTmpName) + val keyNode = createFieldIdentifierNode(key.code, key.lineNumber, key.columnNumber) + val accessAst = createFieldAccessCallAst( + fieldAccessTmpNode, + keyNode, + element.lineNumber, + element.columnNumber + ) + createAssignmentCallAst( + valueAst, + accessAst, + s"${codeOf(valueAst.nodes.head)} = ${codeOf(accessAst.nodes.head)}", + element.lineNumber, + element.columnNumber + ) + end convertDestructingObjectElement + + private def convertDestructingArrayElement( + element: BabelNodeInfo, + index: Int, + localTmpName: String + ): Ast = + val valueAst = astForNode(element.json) + + val localNode = newLocalNode(element.code, Defines.Any).order(0) + diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) + scope.addVariable(element.code, localNode, MethodScope) + + val fieldAccessTmpNode = identifierNode(element, localTmpName) + val keyNode = literalNode(element, index.toString, Option(Defines.Number)) + val accessAst = createIndexAccessCallAst( + fieldAccessTmpNode, + keyNode, + element.lineNumber, + element.columnNumber + ) + createAssignmentCallAst( + valueAst, + accessAst, + s"${codeOf(valueAst.nodes.head)} = ${codeOf(accessAst.nodes.head)}", + element.lineNumber, + element.columnNumber + ) + end convertDestructingArrayElement + + private def convertDestructingArrayElementWithDefault( + element: BabelNodeInfo, + index: Int, + localTmpName: String + ): Ast = + val rhsElement = element.json("right") + val rhsAst = astForNodeWithFunctionReference(rhsElement) + + val lhsElement = element.json("left") + val nodeInfo = createBabelNodeInfo(lhsElement) + val lhsAst = nodeInfo.node match + case ObjectPattern | ArrayPattern => + val sourceAst = + astForNodeWithFunctionReference(createBabelNodeInfo(rhsElement).json) + astForDeconstruction(nodeInfo, sourceAst, element.code) + case _ => astForNodeWithFunctionReference(lhsElement) + + val testAst = + val fieldAccessTmpNode = identifierNode(element, localTmpName) + val keyNode = literalNode(element, index.toString, Option(Defines.Number)) + val accessAst = + createIndexAccessCallAst( + fieldAccessTmpNode, + keyNode, + element.lineNumber, + element.columnNumber + ) + val voidCallNode = createVoidCallNode(element.lineNumber, element.columnNumber) + createEqualsCallAst( + accessAst, + Ast(voidCallNode), + element.lineNumber, + element.columnNumber + ) + val falseAst = + val fieldAccessTmpNode = identifierNode(element, localTmpName) + val keyNode = literalNode(element, index.toString, Option(Defines.Number)) + createIndexAccessCallAst( + fieldAccessTmpNode, + keyNode, + element.lineNumber, + element.columnNumber + ) + val ternaryNodeAst = + createTernaryCallAst( + testAst, + rhsAst, + falseAst, element.lineNumber, element.columnNumber ) - end convertDestructingArrayElementWithDefault - - private def convertDestructingObjectElementWithDefault( - element: BabelNodeInfo, - key: BabelNodeInfo, - localTmpName: String - ): Ast = - val rhsElement = element.json("right") - val rhsAst = astForNodeWithFunctionReference(rhsElement) - - val lhsElement = element.json("left") - val nodeInfo = createBabelNodeInfo(lhsElement) - val lhsAst = nodeInfo.node match - case ObjectPattern | ArrayPattern => - val sourceAst = - astForNodeWithFunctionReference(createBabelNodeInfo(rhsElement).json) - astForDeconstruction(nodeInfo, sourceAst, element.code) - case _ => astForNodeWithFunctionReference(lhsElement) - - val testAst = - val fieldAccessTmpNode = identifierNode(element, localTmpName) - val keyNode = createFieldIdentifierNode(key.code, key.lineNumber, key.columnNumber) - val accessAst = - createFieldAccessCallAst( - fieldAccessTmpNode, - keyNode, - element.lineNumber, - element.columnNumber - ) - val voidCallNode = createVoidCallNode(element.lineNumber, element.columnNumber) - createEqualsCallAst( - accessAst, - Ast(voidCallNode), - element.lineNumber, - element.columnNumber - ) - val falseAst = - val fieldAccessTmpNode = identifierNode(element, localTmpName) - val keyNode = createFieldIdentifierNode(key.code, key.lineNumber, key.columnNumber) - createFieldAccessCallAst( - fieldAccessTmpNode, - keyNode, - element.lineNumber, - element.columnNumber - ) - val ternaryNodeAst = - createTernaryCallAst( - testAst, - rhsAst, - falseAst, - element.lineNumber, - element.columnNumber - ) - createAssignmentCallAst( - lhsAst, - ternaryNodeAst, - s"${codeOf(lhsAst.nodes.head)} = ${codeOf(ternaryNodeAst.nodes.head)}", + createAssignmentCallAst( + lhsAst, + ternaryNodeAst, + s"${codeOf(lhsAst.nodes.head)} = ${codeOf(ternaryNodeAst.nodes.head)}", + element.lineNumber, + element.columnNumber + ) + end convertDestructingArrayElementWithDefault + + private def convertDestructingObjectElementWithDefault( + element: BabelNodeInfo, + key: BabelNodeInfo, + localTmpName: String + ): Ast = + val rhsElement = element.json("right") + val rhsAst = astForNodeWithFunctionReference(rhsElement) + + val lhsElement = element.json("left") + val nodeInfo = createBabelNodeInfo(lhsElement) + val lhsAst = nodeInfo.node match + case ObjectPattern | ArrayPattern => + val sourceAst = + astForNodeWithFunctionReference(createBabelNodeInfo(rhsElement).json) + astForDeconstruction(nodeInfo, sourceAst, element.code) + case _ => astForNodeWithFunctionReference(lhsElement) + + val testAst = + val fieldAccessTmpNode = identifierNode(element, localTmpName) + val keyNode = createFieldIdentifierNode(key.code, key.lineNumber, key.columnNumber) + val accessAst = + createFieldAccessCallAst( + fieldAccessTmpNode, + keyNode, + element.lineNumber, + element.columnNumber + ) + val voidCallNode = createVoidCallNode(element.lineNumber, element.columnNumber) + createEqualsCallAst( + accessAst, + Ast(voidCallNode), + element.lineNumber, + element.columnNumber + ) + val falseAst = + val fieldAccessTmpNode = identifierNode(element, localTmpName) + val keyNode = createFieldIdentifierNode(key.code, key.lineNumber, key.columnNumber) + createFieldAccessCallAst( + fieldAccessTmpNode, + keyNode, + element.lineNumber, + element.columnNumber + ) + val ternaryNodeAst = + createTernaryCallAst( + testAst, + rhsAst, + falseAst, element.lineNumber, element.columnNumber ) - end convertDestructingObjectElementWithDefault - - private def createParamAst(pattern: BabelNodeInfo, keyName: String, sourceAst: Ast): Ast = - val testAst = - val lhsNode = identifierNode(pattern, keyName) - scope.addVariableReference(keyName, lhsNode) - val rhsNode = - callNode(pattern, "void 0", ".void", DispatchTypes.STATIC_DISPATCH) - createEqualsCallAst( - Ast(lhsNode), - Ast(rhsNode), - pattern.lineNumber, - pattern.columnNumber - ) - - val falseNode = - val initNode = identifierNode(pattern, keyName) - scope.addVariableReference(keyName, initNode) - initNode - createTernaryCallAst( - testAst, - sourceAst, - Ast(falseNode), + createAssignmentCallAst( + lhsAst, + ternaryNodeAst, + s"${codeOf(lhsAst.nodes.head)} = ${codeOf(ternaryNodeAst.nodes.head)}", + element.lineNumber, + element.columnNumber + ) + end convertDestructingObjectElementWithDefault + + private def createParamAst(pattern: BabelNodeInfo, keyName: String, sourceAst: Ast): Ast = + val testAst = + val lhsNode = identifierNode(pattern, keyName) + scope.addVariableReference(keyName, lhsNode) + val rhsNode = + callNode(pattern, "void 0", ".void", DispatchTypes.STATIC_DISPATCH) + createEqualsCallAst( + Ast(lhsNode), + Ast(rhsNode), + pattern.lineNumber, + pattern.columnNumber + ) + + val falseNode = + val initNode = identifierNode(pattern, keyName) + scope.addVariableReference(keyName, initNode) + initNode + createTernaryCallAst( + testAst, + sourceAst, + Ast(falseNode), + pattern.lineNumber, + pattern.columnNumber + ) + end createParamAst + + protected def astForDeconstruction( + pattern: BabelNodeInfo, + sourceAst: Ast, + code: String, + paramName: Option[String] = None + ): Ast = + val localTmpName = generateUnusedVariableName(usedVariableNames, "_tmp") + + val blockNode = createBlockNode(pattern, Option(code)) + scope.pushNewBlockScope(blockNode) + localAstParentStack.push(blockNode) + + val localNode = newLocalNode(localTmpName, Defines.Any).order(0) + val tmpNode = identifierNode(pattern, localTmpName) + diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) + scope.addVariable(localTmpName, localNode, BlockScope) + scope.addVariableReference(localTmpName, tmpNode) + + val rhsAssignmentAst = + paramName.map(createParamAst(pattern, _, sourceAst)).getOrElse(sourceAst) + val assignmentTmpCallAst = + createAssignmentCallAst( + Ast(tmpNode), + rhsAssignmentAst, + s"$localTmpName = ${codeOf(rhsAssignmentAst.nodes.head)}", pattern.lineNumber, pattern.columnNumber ) - end createParamAst - - protected def astForDeconstruction( - pattern: BabelNodeInfo, - sourceAst: Ast, - code: String, - paramName: Option[String] = None - ): Ast = - val localTmpName = generateUnusedVariableName(usedVariableNames, "_tmp") - - val blockNode = createBlockNode(pattern, Option(code)) - scope.pushNewBlockScope(blockNode) - localAstParentStack.push(blockNode) - - val localNode = newLocalNode(localTmpName, Defines.Any).order(0) - val tmpNode = identifierNode(pattern, localTmpName) - diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) - scope.addVariable(localTmpName, localNode, BlockScope) - scope.addVariableReference(localTmpName, tmpNode) - - val rhsAssignmentAst = - paramName.map(createParamAst(pattern, _, sourceAst)).getOrElse(sourceAst) - val assignmentTmpCallAst = - createAssignmentCallAst( - Ast(tmpNode), - rhsAssignmentAst, - s"$localTmpName = ${codeOf(rhsAssignmentAst.nodes.head)}", - pattern.lineNumber, - pattern.columnNumber - ) - - val subTreeAsts = pattern.node match - case ObjectPattern => - pattern.json("properties").arr.toList.map { element => - val nodeInfo = createBabelNodeInfo(element) - nodeInfo.node match - case RestElement => - val arg1Ast = Ast(identifierNode(nodeInfo, localTmpName)) - astForSpreadOrRestElement(nodeInfo, Option(arg1Ast)) - case _ => - val nodeInfo = createBabelNodeInfo(element("value")) - nodeInfo.node match - case Identifier => - convertDestructingObjectElement( - nodeInfo, - createBabelNodeInfo(element("key")), - localTmpName - ) - case AssignmentPattern => - convertDestructingObjectElementWithDefault( - nodeInfo, - createBabelNodeInfo(element("key")), - localTmpName - ) - case _ => astForNodeWithFunctionReference(nodeInfo.json) - end match - } - case ArrayPattern => - pattern.json("elements").arr.toList.zipWithIndex.map { - case (element, index) if !element.isNull => - val nodeInfo = createBabelNodeInfo(element) - nodeInfo.node match - case RestElement => - val fieldAccessTmpNode = identifierNode(nodeInfo, localTmpName) - val keyNode = - literalNode(nodeInfo, index.toString, Option(Defines.Number)) - val accessAst = - createIndexAccessCallAst( - fieldAccessTmpNode, - keyNode, - nodeInfo.lineNumber, - nodeInfo.columnNumber - ) - astForSpreadOrRestElement(nodeInfo, Option(accessAst)) - case Identifier => - convertDestructingArrayElement(nodeInfo, index, localTmpName) - case AssignmentPattern => - convertDestructingArrayElementWithDefault( - nodeInfo, - index, - localTmpName - ) - case _ => astForNodeWithFunctionReference(nodeInfo.json) - end match - case _ => Ast() - } - case _ => - List(convertDestructingObjectElement(pattern, pattern, localTmpName)) - - val returnTmpNode = identifierNode(pattern, localTmpName) - scope.popScope() - localAstParentStack.pop() - val blockChildren = assignmentTmpCallAst +: subTreeAsts :+ Ast(returnTmpNode) - setArgumentIndices(blockChildren) - blockAst(blockNode, blockChildren) - end astForDeconstruction + val subTreeAsts = pattern.node match + case ObjectPattern => + pattern.json("properties").arr.toList.map { element => + val nodeInfo = createBabelNodeInfo(element) + nodeInfo.node match + case RestElement => + val arg1Ast = Ast(identifierNode(nodeInfo, localTmpName)) + astForSpreadOrRestElement(nodeInfo, Option(arg1Ast)) + case _ => + val nodeInfo = createBabelNodeInfo(element("value")) + nodeInfo.node match + case Identifier => + convertDestructingObjectElement( + nodeInfo, + createBabelNodeInfo(element("key")), + localTmpName + ) + case AssignmentPattern => + convertDestructingObjectElementWithDefault( + nodeInfo, + createBabelNodeInfo(element("key")), + localTmpName + ) + case _ => astForNodeWithFunctionReference(nodeInfo.json) + end match + } + case ArrayPattern => + pattern.json("elements").arr.toList.zipWithIndex.map { + case (element, index) if !element.isNull => + val nodeInfo = createBabelNodeInfo(element) + nodeInfo.node match + case RestElement => + val fieldAccessTmpNode = identifierNode(nodeInfo, localTmpName) + val keyNode = + literalNode(nodeInfo, index.toString, Option(Defines.Number)) + val accessAst = + createIndexAccessCallAst( + fieldAccessTmpNode, + keyNode, + nodeInfo.lineNumber, + nodeInfo.columnNumber + ) + astForSpreadOrRestElement(nodeInfo, Option(accessAst)) + case Identifier => + convertDestructingArrayElement(nodeInfo, index, localTmpName) + case AssignmentPattern => + convertDestructingArrayElementWithDefault( + nodeInfo, + index, + localTmpName + ) + case _ => astForNodeWithFunctionReference(nodeInfo.json) + end match + case _ => Ast() + } + case _ => + List(convertDestructingObjectElement(pattern, pattern, localTmpName)) + + val returnTmpNode = identifierNode(pattern, localTmpName) + scope.popScope() + localAstParentStack.pop() + + val blockChildren = assignmentTmpCallAst +: subTreeAsts :+ Ast(returnTmpNode) + setArgumentIndices(blockChildren) + blockAst(blockNode, blockChildren) + end astForDeconstruction end AstForDeclarationsCreator diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForExpressionsCreator.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForExpressionsCreator.scala index cc55acd8..401becc7 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForExpressionsCreator.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForExpressionsCreator.scala @@ -12,574 +12,574 @@ import io.shiftleft.codepropertygraph.generated.{DispatchTypes, EdgeTypes, Opera import scala.util.Try trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode): - this: AstCreator => - - protected def astForExpressionStatement(exprStmt: BabelNodeInfo): Ast = - astForNodeWithFunctionReference(exprStmt.json("expression")) - - private def createBuiltinStaticCall( - callExpr: BabelNodeInfo, - callee: BabelNodeInfo, - fullName: String - ): Ast = - val callName = callee.node match - case MemberExpression => code(callee.json("property")) - case _ => callee.code - val callNode = - createStaticCallNode( - callExpr.code, - callName, - fullName, - callee.lineNumber, - callee.columnNumber - ) - val argAsts = astForNodes(callExpr.json("arguments").arr.toList) - callAst(callNode, argAsts) - - private def handleCallNodeArgs( - callExpr: BabelNodeInfo, - receiverAst: Ast, - baseNode: NewNode, - callName: String - ): Ast = - val args = astForNodes(callExpr.json("arguments").arr.toList) - val callNode_ = callNode(callExpr, callExpr.code, callName, DispatchTypes.DYNAMIC_DISPATCH) - // If the callee is a function itself, e.g. closure, then resolve this locally, if possible - callExpr.json.obj - .get("callee") - .map(createBabelNodeInfo) - .flatMap { - case callee if callee.node.isInstanceOf[FunctionLike] => - functionNodeToNameAndFullName.get(callee) - case _ => None - } - .foreach { case (name, fullName) => callNode_.name(name).methodFullName(fullName) } - callAst(callNode_, args, receiver = Option(receiverAst), base = Option(Ast(baseNode))) - end handleCallNodeArgs - - protected def astForCallExpression(callExpr: BabelNodeInfo): Ast = - val callee = createBabelNodeInfo(callExpr.json("callee")) - val calleeCode = callee.code - if GlobalBuiltins.builtins.contains(calleeCode) then - createBuiltinStaticCall(callExpr, callee, calleeCode) - else - val (receiverAst, baseNode, callName) = callee.node match - case MemberExpression => - val base = createBabelNodeInfo(callee.json("object")) - val member = createBabelNodeInfo(callee.json("property")) - base.node match - case ThisExpression => - val receiverAst = astForNodeWithFunctionReference(callee.json) - val baseNode = identifierNode(base, base.code) - .dynamicTypeHintFullName(this.rootTypeDecl.map(_.fullName).toSeq) - scope.addVariableReference(base.code, baseNode) - (receiverAst, baseNode, member.code) - case Identifier => - val receiverAst = astForNodeWithFunctionReference(callee.json) - val baseNode = identifierNode(base, base.code) - scope.addVariableReference(base.code, baseNode) - (receiverAst, baseNode, member.code) - case _ => - val tmpVarName = generateUnusedVariableName(usedVariableNames, "_tmp") - val baseTmpNode = identifierNode(base, tmpVarName) - scope.addVariableReference(tmpVarName, baseTmpNode) - val baseAst = astForNodeWithFunctionReference(base.json) - val code = s"(${codeOf(baseTmpNode)} = ${base.code})" - val tmpAssignmentAst = - createAssignmentCallAst( - Ast(baseTmpNode), - baseAst, - code, - base.lineNumber, - base.columnNumber - ) - val memberNode = createFieldIdentifierNode( - member.code, - member.lineNumber, - member.columnNumber - ) - val fieldAccessAst = - createFieldAccessCallAst( - tmpAssignmentAst, - memberNode, - callee.lineNumber, - callee.columnNumber - ) - val thisTmpNode = identifierNode(callee, tmpVarName) - scope.addVariableReference(tmpVarName, thisTmpNode) - - (fieldAccessAst, thisTmpNode, member.code) - end match - case _ => - val receiverAst = astForNodeWithFunctionReference(callee.json) - val thisNode = identifierNode(callee, "this").dynamicTypeHintFullName( - typeHintForThisExpression() - ) - scope.addVariableReference(thisNode.name, thisNode) - (receiverAst, thisNode, calleeCode) - handleCallNodeArgs(callExpr, receiverAst, baseNode, callName) - end if - end astForCallExpression - - protected def astForThisExpression(thisExpr: BabelNodeInfo): Ast = - val dynamicTypeOption = typeHintForThisExpression(Option(thisExpr)).headOption - val thisNode = identifierNode(thisExpr, thisExpr.code, dynamicTypeOption.toList) - scope.addVariableReference(thisExpr.code, thisNode) - Ast(thisNode) - - protected def astForNewExpression(newExpr: BabelNodeInfo): Ast = - val callee = newExpr.json("callee") - val blockNode = createBlockNode(newExpr) - - scope.pushNewBlockScope(blockNode) - localAstParentStack.push(blockNode) - - val tmpAllocName = generateUnusedVariableName(usedVariableNames, "_tmp") - val localTmpAllocNode = newLocalNode(tmpAllocName, Defines.Any).order(0) - val tmpAllocNode1 = identifierNode(newExpr, tmpAllocName) - diffGraph.addEdge(localAstParentStack.head, localTmpAllocNode, EdgeTypes.AST) - scope.addVariableReference(tmpAllocName, tmpAllocNode1) - - val allocCallNode = - callNode(newExpr, ".alloc", Operators.alloc, DispatchTypes.STATIC_DISPATCH) - - val assignmentTmpAllocCallNode = - createAssignmentCallAst( - tmpAllocNode1, - allocCallNode, - s"$tmpAllocName = ${allocCallNode.code}", - newExpr.lineNumber, - newExpr.columnNumber - ) - - val tmpAllocNode2 = identifierNode(newExpr, tmpAllocName) - - val receiverNode = astForNodeWithFunctionReference(callee) - - // TODO: place ".new" into the schema - val callAst = handleCallNodeArgs(newExpr, receiverNode, tmpAllocNode2, ".new") - - val tmpAllocReturnNode = Ast(identifierNode(newExpr, tmpAllocName)) - - scope.popScope() - localAstParentStack.pop() + this: AstCreator => + + protected def astForExpressionStatement(exprStmt: BabelNodeInfo): Ast = + astForNodeWithFunctionReference(exprStmt.json("expression")) + + private def createBuiltinStaticCall( + callExpr: BabelNodeInfo, + callee: BabelNodeInfo, + fullName: String + ): Ast = + val callName = callee.node match + case MemberExpression => code(callee.json("property")) + case _ => callee.code + val callNode = + createStaticCallNode( + callExpr.code, + callName, + fullName, + callee.lineNumber, + callee.columnNumber + ) + val argAsts = astForNodes(callExpr.json("arguments").arr.toList) + callAst(callNode, argAsts) + + private def handleCallNodeArgs( + callExpr: BabelNodeInfo, + receiverAst: Ast, + baseNode: NewNode, + callName: String + ): Ast = + val args = astForNodes(callExpr.json("arguments").arr.toList) + val callNode_ = callNode(callExpr, callExpr.code, callName, DispatchTypes.DYNAMIC_DISPATCH) + // If the callee is a function itself, e.g. closure, then resolve this locally, if possible + callExpr.json.obj + .get("callee") + .map(createBabelNodeInfo) + .flatMap { + case callee if callee.node.isInstanceOf[FunctionLike] => + functionNodeToNameAndFullName.get(callee) + case _ => None + } + .foreach { case (name, fullName) => callNode_.name(name).methodFullName(fullName) } + callAst(callNode_, args, receiver = Option(receiverAst), base = Option(Ast(baseNode))) + end handleCallNodeArgs + + protected def astForCallExpression(callExpr: BabelNodeInfo): Ast = + val callee = createBabelNodeInfo(callExpr.json("callee")) + val calleeCode = callee.code + if GlobalBuiltins.builtins.contains(calleeCode) then + createBuiltinStaticCall(callExpr, callee, calleeCode) + else + val (receiverAst, baseNode, callName) = callee.node match + case MemberExpression => + val base = createBabelNodeInfo(callee.json("object")) + val member = createBabelNodeInfo(callee.json("property")) + base.node match + case ThisExpression => + val receiverAst = astForNodeWithFunctionReference(callee.json) + val baseNode = identifierNode(base, base.code) + .dynamicTypeHintFullName(this.rootTypeDecl.map(_.fullName).toSeq) + scope.addVariableReference(base.code, baseNode) + (receiverAst, baseNode, member.code) + case Identifier => + val receiverAst = astForNodeWithFunctionReference(callee.json) + val baseNode = identifierNode(base, base.code) + scope.addVariableReference(base.code, baseNode) + (receiverAst, baseNode, member.code) + case _ => + val tmpVarName = generateUnusedVariableName(usedVariableNames, "_tmp") + val baseTmpNode = identifierNode(base, tmpVarName) + scope.addVariableReference(tmpVarName, baseTmpNode) + val baseAst = astForNodeWithFunctionReference(base.json) + val code = s"(${codeOf(baseTmpNode)} = ${base.code})" + val tmpAssignmentAst = + createAssignmentCallAst( + Ast(baseTmpNode), + baseAst, + code, + base.lineNumber, + base.columnNumber + ) + val memberNode = createFieldIdentifierNode( + member.code, + member.lineNumber, + member.columnNumber + ) + val fieldAccessAst = + createFieldAccessCallAst( + tmpAssignmentAst, + memberNode, + callee.lineNumber, + callee.columnNumber + ) + val thisTmpNode = identifierNode(callee, tmpVarName) + scope.addVariableReference(tmpVarName, thisTmpNode) - val blockChildren = List(assignmentTmpAllocCallNode, callAst, tmpAllocReturnNode) - setArgumentIndices(blockChildren) - blockAst(blockNode, blockChildren) - end astForNewExpression + (fieldAccessAst, thisTmpNode, member.code) + end match + case _ => + val receiverAst = astForNodeWithFunctionReference(callee.json) + val thisNode = identifierNode(callee, "this").dynamicTypeHintFullName( + typeHintForThisExpression() + ) + scope.addVariableReference(thisNode.name, thisNode) + (receiverAst, thisNode, calleeCode) + handleCallNodeArgs(callExpr, receiverAst, baseNode, callName) + end if + end astForCallExpression + + protected def astForThisExpression(thisExpr: BabelNodeInfo): Ast = + val dynamicTypeOption = typeHintForThisExpression(Option(thisExpr)).headOption + val thisNode = identifierNode(thisExpr, thisExpr.code, dynamicTypeOption.toList) + scope.addVariableReference(thisExpr.code, thisNode) + Ast(thisNode) + + protected def astForNewExpression(newExpr: BabelNodeInfo): Ast = + val callee = newExpr.json("callee") + val blockNode = createBlockNode(newExpr) + + scope.pushNewBlockScope(blockNode) + localAstParentStack.push(blockNode) + + val tmpAllocName = generateUnusedVariableName(usedVariableNames, "_tmp") + val localTmpAllocNode = newLocalNode(tmpAllocName, Defines.Any).order(0) + val tmpAllocNode1 = identifierNode(newExpr, tmpAllocName) + diffGraph.addEdge(localAstParentStack.head, localTmpAllocNode, EdgeTypes.AST) + scope.addVariableReference(tmpAllocName, tmpAllocNode1) + + val allocCallNode = + callNode(newExpr, ".alloc", Operators.alloc, DispatchTypes.STATIC_DISPATCH) + + val assignmentTmpAllocCallNode = + createAssignmentCallAst( + tmpAllocNode1, + allocCallNode, + s"$tmpAllocName = ${allocCallNode.code}", + newExpr.lineNumber, + newExpr.columnNumber + ) - protected def astForMetaProperty(metaProperty: BabelNodeInfo): Ast = - val metaAst = astForIdentifier(createBabelNodeInfo(metaProperty.json("meta"))) - val memberNodeInfo = createBabelNodeInfo(metaProperty.json("property")) - val memberAst = Ast( - createFieldIdentifierNode( - memberNodeInfo.code, - memberNodeInfo.lineNumber, - memberNodeInfo.columnNumber - ) + val tmpAllocNode2 = identifierNode(newExpr, tmpAllocName) + + val receiverNode = astForNodeWithFunctionReference(callee) + + // TODO: place ".new" into the schema + val callAst = handleCallNodeArgs(newExpr, receiverNode, tmpAllocNode2, ".new") + + val tmpAllocReturnNode = Ast(identifierNode(newExpr, tmpAllocName)) + + scope.popScope() + localAstParentStack.pop() + + val blockChildren = List(assignmentTmpAllocCallNode, callAst, tmpAllocReturnNode) + setArgumentIndices(blockChildren) + blockAst(blockNode, blockChildren) + end astForNewExpression + + protected def astForMetaProperty(metaProperty: BabelNodeInfo): Ast = + val metaAst = astForIdentifier(createBabelNodeInfo(metaProperty.json("meta"))) + val memberNodeInfo = createBabelNodeInfo(metaProperty.json("property")) + val memberAst = Ast( + createFieldIdentifierNode( + memberNodeInfo.code, + memberNodeInfo.lineNumber, + memberNodeInfo.columnNumber + ) + ) + createFieldAccessCallAst( + metaAst, + memberAst.nodes.head, + metaProperty.lineNumber, + metaProperty.columnNumber + ) + + protected def astForMemberExpression(memberExpr: BabelNodeInfo): Ast = + val baseAst = astForNodeWithFunctionReference(memberExpr.json("object")) + val memberIsComputed = memberExpr.json("computed").bool + val memberNodeInfo = createBabelNodeInfo(memberExpr.json("property")) + if memberIsComputed then + val memberAst = astForNode(memberNodeInfo.json) + createIndexAccessCallAst( + baseAst, + memberAst, + memberExpr.lineNumber, + memberExpr.columnNumber + ) + else + val memberAst = Ast( + createFieldIdentifierNode( + memberNodeInfo.code, + memberNodeInfo.lineNumber, + memberNodeInfo.columnNumber ) - createFieldAccessCallAst( - metaAst, - memberAst.nodes.head, - metaProperty.lineNumber, - metaProperty.columnNumber + ) + createFieldAccessCallAst( + baseAst, + memberAst.nodes.head, + memberExpr.lineNumber, + memberExpr.columnNumber + ) + end if + end astForMemberExpression + + protected def astForAssignmentExpression(assignment: BabelNodeInfo): Ast = + val op = if hasKey(assignment.json, "operator") then + assignment.json("operator").str match + case "=" => Operators.assignment + case "+=" => Operators.assignmentPlus + case "-=" => Operators.assignmentMinus + case "*=" => Operators.assignmentMultiplication + case "/=" => Operators.assignmentDivision + case "%=" => Operators.assignmentModulo + case "**=" => Operators.assignmentExponentiation + case "&=" => Operators.assignmentAnd + case "&&=" => Operators.assignmentAnd + case "|=" => Operators.assignmentOr + case "||=" => Operators.assignmentOr + case "^=" => Operators.assignmentXor + case "<<=" => Operators.assignmentShiftLeft + case ">>=" => Operators.assignmentArithmeticShiftRight + case ">>>=" => Operators.assignmentLogicalShiftRight + case "??=" => Operators.notNullAssert + case other => + logger.warn(s"Unknown assignment operator: '$other'") + Operators.assignment + else Operators.assignment + + val nodeInfo = createBabelNodeInfo(assignment.json("left")) + nodeInfo.node match + case ObjectPattern | ArrayPattern => + val rhsAst = astForNodeWithFunctionReference(assignment.json("right")) + astForDeconstruction(nodeInfo, rhsAst, assignment.code) + case _ => + val lhsAst = astForNode(assignment.json("left")) + val rhsAst = astForNodeWithFunctionReference(assignment.json("right")) + val callNode_ = + callNode(assignment, assignment.code, op, DispatchTypes.STATIC_DISPATCH) + val argAsts = List(lhsAst, rhsAst) + callAst(callNode_, argAsts) + end astForAssignmentExpression + + protected def astForConditionalExpression(ternary: BabelNodeInfo): Ast = + val testAst = astForNodeWithFunctionReference(ternary.json("test")) + val consequentAst = astForNodeWithFunctionReference(ternary.json("consequent")) + val alternateAst = astForNodeWithFunctionReference(ternary.json("alternate")) + createTernaryCallAst( + testAst, + consequentAst, + alternateAst, + ternary.lineNumber, + ternary.columnNumber + ) + + protected def astForLogicalExpression(logicalExpr: BabelNodeInfo): Ast = + astForBinaryExpression(logicalExpr) + + protected def astForTSNonNullExpression(nonNullExpr: BabelNodeInfo): Ast = + val op = Operators.notNullAssert + val callNode_ = + callNode(nonNullExpr, nonNullExpr.code, op, DispatchTypes.STATIC_DISPATCH) + val argAsts = List(astForNodeWithFunctionReference(nonNullExpr.json("expression"))) + callAst(callNode_, argAsts) + + protected def astForCastExpression(castExpr: BabelNodeInfo): Ast = + val op = Operators.cast + val typ = typeFor(castExpr) match + case t if GlobalBuiltins.builtins.contains(t) => s"__ecma.$t" + case t => t + val lhsNode = castExpr.json("typeAnnotation") + val lhsAst = + Ast(literalNode(castExpr, code(lhsNode), None).dynamicTypeHintFullName(Seq(typ))) + val rhsAst = astForNodeWithFunctionReference(castExpr.json("expression")) + val node = + callNode(castExpr, castExpr.code, op, DispatchTypes.STATIC_DISPATCH) + .dynamicTypeHintFullName(Seq(typ)) + val argAsts = List(lhsAst, rhsAst) + callAst(node, argAsts) + + protected def astForBinaryExpression(binExpr: BabelNodeInfo): Ast = + val op = binExpr.json("operator").str match + case "+" => Operators.addition + case "-" => Operators.subtraction + case "/" => Operators.division + case "%" => Operators.modulo + case "*" => Operators.multiplication + case "**" => Operators.exponentiation + case "&" => Operators.and + case ">>" => Operators.arithmeticShiftRight + case ">>>" => Operators.arithmeticShiftRight + case "<<" => Operators.shiftLeft + case "^" => Operators.xor + case "==" => Operators.equals + case "===" => Operators.equals + case "!=" => Operators.notEquals + case "!==" => Operators.notEquals + case "in" => Operators.in + case ">" => Operators.greaterThan + case "<" => Operators.lessThan + case ">=" => Operators.greaterEqualsThan + case "<=" => Operators.lessEqualsThan + case "instanceof" => Operators.instanceOf + case "||" => Operators.logicalOr + case "|" => Operators.or + case "&&" => Operators.logicalAnd + // special case (see: https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Operators/Nullish_coalescing_operator) + case "??" => Operators.logicalOr + case "case" => ".case" + case other => + logger.warn(s"Unknown binary operator: '$other'") + Operators.assignment + + val lhsAst = astForNodeWithFunctionReference(binExpr.json("left")) + val rhsAst = astForNodeWithFunctionReference(binExpr.json("right")) + + val node = + callNode(binExpr, binExpr.code, op, DispatchTypes.STATIC_DISPATCH) + val argAsts = List(lhsAst, rhsAst) + callAst(node, argAsts) + end astForBinaryExpression + + protected def astForUpdateExpression(updateExpr: BabelNodeInfo): Ast = + val isPrefix = updateExpr.json("prefix").bool + val op = updateExpr.json("operator").str match + case "++" if isPrefix => Operators.preIncrement + case "++" => Operators.postIncrement + case "--" if isPrefix => Operators.preIncrement + case "--" => Operators.postIncrement + case other => + logger.warn(s"Unknown update operator: '$other'") + Operators.assignment + + val argumentAst = astForNodeWithFunctionReference(updateExpr.json("argument")) + + val node = callNode(updateExpr, updateExpr.code, op, DispatchTypes.STATIC_DISPATCH) + val argAsts = List(argumentAst) + callAst(node, argAsts) + + protected def astForUnaryExpression(unaryExpr: BabelNodeInfo): Ast = + val op = unaryExpr.json("operator").str match + case "void" => ".void" + case "throw" => ".throw" + case "delete" => Operators.delete + case "!" => Operators.logicalNot + case "+" => Operators.plus + case "-" => Operators.minus + case "~" => ".bitNot" + case "typeof" => Operators.instanceOf + case other => + logger.warn(s"Unknown update operator: '$other'") + Operators.assignment + + val argumentAst = astForNodeWithFunctionReference(unaryExpr.json("argument")) + + val node = callNode(unaryExpr, unaryExpr.code, op, DispatchTypes.STATIC_DISPATCH) + val argAsts = List(argumentAst) + callAst(node, argAsts) + end astForUnaryExpression + + protected def astForSequenceExpression(seq: BabelNodeInfo): Ast = + val blockNode = createBlockNode(seq) + scope.pushNewBlockScope(blockNode) + localAstParentStack.push(blockNode) + val sequenceExpressionAsts = createBlockStatementAsts(seq.json("expressions")) + setArgumentIndices(sequenceExpressionAsts) + localAstParentStack.pop() + scope.popScope() + blockAst(blockNode, sequenceExpressionAsts) + + protected def astForAwaitExpression(awaitExpr: BabelNodeInfo): Ast = + val node = + callNode(awaitExpr, awaitExpr.code, ".await", DispatchTypes.STATIC_DISPATCH) + val argAsts = List(astForNodeWithFunctionReference(awaitExpr.json("argument"))) + callAst(node, argAsts) + + protected def astForArrayExpression(arrExpr: BabelNodeInfo): Ast = + val elements = Try(arrExpr.json("elements").arr).toOption.toList.flatten + if elements.isEmpty then + Ast( + callNode( + arrExpr, + s"${EcmaBuiltins.arrayFactory}()", + EcmaBuiltins.arrayFactory, + DispatchTypes.STATIC_DISPATCH ) + ) + else + val blockNode = createBlockNode(arrExpr) + scope.pushNewBlockScope(blockNode) + localAstParentStack.push(blockNode) + + val tmpName = generateUnusedVariableName(usedVariableNames, "_tmp") + val localTmpNode = newLocalNode(tmpName, Defines.Any).order(0) + val tmpArrayNode = identifierNode(arrExpr, tmpName) + diffGraph.addEdge(localAstParentStack.head, localTmpNode, EdgeTypes.AST) + scope.addVariableReference(tmpName, tmpArrayNode) + + val arrayCallNode = + callNode( + arrExpr, + s"${EcmaBuiltins.arrayFactory}()", + EcmaBuiltins.arrayFactory, + DispatchTypes.STATIC_DISPATCH + ) - protected def astForMemberExpression(memberExpr: BabelNodeInfo): Ast = - val baseAst = astForNodeWithFunctionReference(memberExpr.json("object")) - val memberIsComputed = memberExpr.json("computed").bool - val memberNodeInfo = createBabelNodeInfo(memberExpr.json("property")) - if memberIsComputed then - val memberAst = astForNode(memberNodeInfo.json) - createIndexAccessCallAst( - baseAst, - memberAst, - memberExpr.lineNumber, - memberExpr.columnNumber - ) - else - val memberAst = Ast( - createFieldIdentifierNode( - memberNodeInfo.code, - memberNodeInfo.lineNumber, - memberNodeInfo.columnNumber - ) - ) - createFieldAccessCallAst( - baseAst, - memberAst.nodes.head, - memberExpr.lineNumber, - memberExpr.columnNumber - ) - end if - end astForMemberExpression - - protected def astForAssignmentExpression(assignment: BabelNodeInfo): Ast = - val op = if hasKey(assignment.json, "operator") then - assignment.json("operator").str match - case "=" => Operators.assignment - case "+=" => Operators.assignmentPlus - case "-=" => Operators.assignmentMinus - case "*=" => Operators.assignmentMultiplication - case "/=" => Operators.assignmentDivision - case "%=" => Operators.assignmentModulo - case "**=" => Operators.assignmentExponentiation - case "&=" => Operators.assignmentAnd - case "&&=" => Operators.assignmentAnd - case "|=" => Operators.assignmentOr - case "||=" => Operators.assignmentOr - case "^=" => Operators.assignmentXor - case "<<=" => Operators.assignmentShiftLeft - case ">>=" => Operators.assignmentArithmeticShiftRight - case ">>>=" => Operators.assignmentLogicalShiftRight - case "??=" => Operators.notNullAssert - case other => - logger.warn(s"Unknown assignment operator: '$other'") - Operators.assignment - else Operators.assignment - - val nodeInfo = createBabelNodeInfo(assignment.json("left")) - nodeInfo.node match - case ObjectPattern | ArrayPattern => - val rhsAst = astForNodeWithFunctionReference(assignment.json("right")) - astForDeconstruction(nodeInfo, rhsAst, assignment.code) - case _ => - val lhsAst = astForNode(assignment.json("left")) - val rhsAst = astForNodeWithFunctionReference(assignment.json("right")) - val callNode_ = - callNode(assignment, assignment.code, op, DispatchTypes.STATIC_DISPATCH) - val argAsts = List(lhsAst, rhsAst) - callAst(callNode_, argAsts) - end astForAssignmentExpression - - protected def astForConditionalExpression(ternary: BabelNodeInfo): Ast = - val testAst = astForNodeWithFunctionReference(ternary.json("test")) - val consequentAst = astForNodeWithFunctionReference(ternary.json("consequent")) - val alternateAst = astForNodeWithFunctionReference(ternary.json("alternate")) - createTernaryCallAst( - testAst, - consequentAst, - alternateAst, - ternary.lineNumber, - ternary.columnNumber - ) + val lineNumber = arrExpr.lineNumber + val columnNumber = arrExpr.columnNumber + val assignmentCode = s"${localTmpNode.code} = ${arrayCallNode.code}" + val assignmentTmpArrayCallNode = + createAssignmentCallAst( + tmpArrayNode, + arrayCallNode, + assignmentCode, + lineNumber, + columnNumber + ) - protected def astForLogicalExpression(logicalExpr: BabelNodeInfo): Ast = - astForBinaryExpression(logicalExpr) - - protected def astForTSNonNullExpression(nonNullExpr: BabelNodeInfo): Ast = - val op = Operators.notNullAssert - val callNode_ = - callNode(nonNullExpr, nonNullExpr.code, op, DispatchTypes.STATIC_DISPATCH) - val argAsts = List(astForNodeWithFunctionReference(nonNullExpr.json("expression"))) - callAst(callNode_, argAsts) - - protected def astForCastExpression(castExpr: BabelNodeInfo): Ast = - val op = Operators.cast - val typ = typeFor(castExpr) match - case t if GlobalBuiltins.builtins.contains(t) => s"__ecma.$t" - case t => t - val lhsNode = castExpr.json("typeAnnotation") - val lhsAst = - Ast(literalNode(castExpr, code(lhsNode), None).dynamicTypeHintFullName(Seq(typ))) - val rhsAst = astForNodeWithFunctionReference(castExpr.json("expression")) - val node = - callNode(castExpr, castExpr.code, op, DispatchTypes.STATIC_DISPATCH) - .dynamicTypeHintFullName(Seq(typ)) - val argAsts = List(lhsAst, rhsAst) - callAst(node, argAsts) - - protected def astForBinaryExpression(binExpr: BabelNodeInfo): Ast = - val op = binExpr.json("operator").str match - case "+" => Operators.addition - case "-" => Operators.subtraction - case "/" => Operators.division - case "%" => Operators.modulo - case "*" => Operators.multiplication - case "**" => Operators.exponentiation - case "&" => Operators.and - case ">>" => Operators.arithmeticShiftRight - case ">>>" => Operators.arithmeticShiftRight - case "<<" => Operators.shiftLeft - case "^" => Operators.xor - case "==" => Operators.equals - case "===" => Operators.equals - case "!=" => Operators.notEquals - case "!==" => Operators.notEquals - case "in" => Operators.in - case ">" => Operators.greaterThan - case "<" => Operators.lessThan - case ">=" => Operators.greaterEqualsThan - case "<=" => Operators.lessEqualsThan - case "instanceof" => Operators.instanceOf - case "||" => Operators.logicalOr - case "|" => Operators.or - case "&&" => Operators.logicalAnd - // special case (see: https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Operators/Nullish_coalescing_operator) - case "??" => Operators.logicalOr - case "case" => ".case" - case other => - logger.warn(s"Unknown binary operator: '$other'") - Operators.assignment - - val lhsAst = astForNodeWithFunctionReference(binExpr.json("left")) - val rhsAst = astForNodeWithFunctionReference(binExpr.json("right")) - - val node = - callNode(binExpr, binExpr.code, op, DispatchTypes.STATIC_DISPATCH) - val argAsts = List(lhsAst, rhsAst) - callAst(node, argAsts) - end astForBinaryExpression - - protected def astForUpdateExpression(updateExpr: BabelNodeInfo): Ast = - val isPrefix = updateExpr.json("prefix").bool - val op = updateExpr.json("operator").str match - case "++" if isPrefix => Operators.preIncrement - case "++" => Operators.postIncrement - case "--" if isPrefix => Operators.preIncrement - case "--" => Operators.postIncrement - case other => - logger.warn(s"Unknown update operator: '$other'") - Operators.assignment - - val argumentAst = astForNodeWithFunctionReference(updateExpr.json("argument")) - - val node = callNode(updateExpr, updateExpr.code, op, DispatchTypes.STATIC_DISPATCH) - val argAsts = List(argumentAst) - callAst(node, argAsts) - - protected def astForUnaryExpression(unaryExpr: BabelNodeInfo): Ast = - val op = unaryExpr.json("operator").str match - case "void" => ".void" - case "throw" => ".throw" - case "delete" => Operators.delete - case "!" => Operators.logicalNot - case "+" => Operators.plus - case "-" => Operators.minus - case "~" => ".bitNot" - case "typeof" => Operators.instanceOf - case other => - logger.warn(s"Unknown update operator: '$other'") - Operators.assignment - - val argumentAst = astForNodeWithFunctionReference(unaryExpr.json("argument")) - - val node = callNode(unaryExpr, unaryExpr.code, op, DispatchTypes.STATIC_DISPATCH) - val argAsts = List(argumentAst) - callAst(node, argAsts) - end astForUnaryExpression - - protected def astForSequenceExpression(seq: BabelNodeInfo): Ast = - val blockNode = createBlockNode(seq) - scope.pushNewBlockScope(blockNode) - localAstParentStack.push(blockNode) - val sequenceExpressionAsts = createBlockStatementAsts(seq.json("expressions")) - setArgumentIndices(sequenceExpressionAsts) - localAstParentStack.pop() - scope.popScope() - blockAst(blockNode, sequenceExpressionAsts) - - protected def astForAwaitExpression(awaitExpr: BabelNodeInfo): Ast = - val node = - callNode(awaitExpr, awaitExpr.code, ".await", DispatchTypes.STATIC_DISPATCH) - val argAsts = List(astForNodeWithFunctionReference(awaitExpr.json("argument"))) - callAst(node, argAsts) - - protected def astForArrayExpression(arrExpr: BabelNodeInfo): Ast = - val elements = Try(arrExpr.json("elements").arr).toOption.toList.flatten - if elements.isEmpty then - Ast( - callNode( - arrExpr, - s"${EcmaBuiltins.arrayFactory}()", - EcmaBuiltins.arrayFactory, - DispatchTypes.STATIC_DISPATCH + val elementAsts = elements.flatMap { + case element if !element.isNull => + val elementNodeInfo = createBabelNodeInfo(element) + val elementLineNumber = elementNodeInfo.lineNumber + val elementColumnNumber = elementNodeInfo.columnNumber + val elementCode = elementNodeInfo.code + val elementNode = elementNodeInfo.node match + case RestElement => + val arg1Ast = Ast(identifierNode(arrExpr, tmpName)) + astForSpreadOrRestElement(elementNodeInfo, Option(arg1Ast)) + case _ => + astForNodeWithFunctionReference(element) + + val pushCallNode = + callNode( + elementNodeInfo, + s"$tmpName.push($elementCode)", + "", + DispatchTypes.DYNAMIC_DISPATCH + ) + + val baseNode = identifierNode(elementNodeInfo, tmpName) + val memberNode = + createFieldIdentifierNode("push", elementLineNumber, elementColumnNumber) + val receiverNode = createFieldAccessCallAst( + baseNode, + memberNode, + elementLineNumber, + elementColumnNumber ) - ) - else - val blockNode = createBlockNode(arrExpr) - scope.pushNewBlockScope(blockNode) - localAstParentStack.push(blockNode) - - val tmpName = generateUnusedVariableName(usedVariableNames, "_tmp") - val localTmpNode = newLocalNode(tmpName, Defines.Any).order(0) - val tmpArrayNode = identifierNode(arrExpr, tmpName) - diffGraph.addEdge(localAstParentStack.head, localTmpNode, EdgeTypes.AST) - scope.addVariableReference(tmpName, tmpArrayNode) - - val arrayCallNode = - callNode( - arrExpr, - s"${EcmaBuiltins.arrayFactory}()", - EcmaBuiltins.arrayFactory, - DispatchTypes.STATIC_DISPATCH + val thisPushNode = identifierNode(elementNodeInfo, tmpName) + + Option( + callAst( + pushCallNode, + List(elementNode), + receiver = Option(receiverNode), + base = Option(Ast(thisPushNode)) ) - - val lineNumber = arrExpr.lineNumber - val columnNumber = arrExpr.columnNumber - val assignmentCode = s"${localTmpNode.code} = ${arrayCallNode.code}" - val assignmentTmpArrayCallNode = - createAssignmentCallAst( - tmpArrayNode, - arrayCallNode, - assignmentCode, - lineNumber, - columnNumber - ) - - val elementAsts = elements.flatMap { - case element if !element.isNull => - val elementNodeInfo = createBabelNodeInfo(element) - val elementLineNumber = elementNodeInfo.lineNumber - val elementColumnNumber = elementNodeInfo.columnNumber - val elementCode = elementNodeInfo.code - val elementNode = elementNodeInfo.node match - case RestElement => - val arg1Ast = Ast(identifierNode(arrExpr, tmpName)) - astForSpreadOrRestElement(elementNodeInfo, Option(arg1Ast)) - case _ => - astForNodeWithFunctionReference(element) - - val pushCallNode = - callNode( - elementNodeInfo, - s"$tmpName.push($elementCode)", - "", - DispatchTypes.DYNAMIC_DISPATCH - ) - - val baseNode = identifierNode(elementNodeInfo, tmpName) - val memberNode = - createFieldIdentifierNode("push", elementLineNumber, elementColumnNumber) - val receiverNode = createFieldAccessCallAst( - baseNode, - memberNode, - elementLineNumber, - elementColumnNumber - ) - val thisPushNode = identifierNode(elementNodeInfo, tmpName) - - Option( - callAst( - pushCallNode, - List(elementNode), - receiver = Option(receiverNode), - base = Option(Ast(thisPushNode)) - ) + ) + case _ => None // skip + } + + val tmpArrayReturnNode = identifierNode(arrExpr, tmpName) + + scope.popScope() + localAstParentStack.pop() + + val blockChildrenAsts = + assignmentTmpArrayCallNode +: elementAsts :+ Ast(tmpArrayReturnNode) + setArgumentIndices(blockChildrenAsts) + blockAst(blockNode, blockChildrenAsts) + end if + end astForArrayExpression + + def astForTemplateExpression(templateExpr: BabelNodeInfo): Ast = + val argumentAst = astForNodeWithFunctionReference(templateExpr.json("quasi")) + val callName = code(templateExpr.json("tag")) + val callCode = s"$callName(${codeOf(argumentAst.nodes.head)})" + val templateExprCall = + callNode(templateExpr, callCode, callName, DispatchTypes.STATIC_DISPATCH) + val argAsts = List(argumentAst) + callAst(templateExprCall, argAsts) + + protected def astForObjectExpression(objExpr: BabelNodeInfo): Ast = + val blockNode = createBlockNode(objExpr) + + scope.pushNewBlockScope(blockNode) + localAstParentStack.push(blockNode) + + val tmpName = generateUnusedVariableName(usedVariableNames, "_tmp") + val localNode = newLocalNode(tmpName, Defines.Any).order(0) + diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) + + val propertiesAsts = objExpr.json("properties").arr.toList.map { property => + val nodeInfo = createBabelNodeInfo(property) + nodeInfo.node match + case SpreadElement | RestElement => + val arg1Ast = Ast(identifierNode(nodeInfo, tmpName)) + astForSpreadOrRestElement(nodeInfo, Option(arg1Ast)) + case _ => + val (lhsNode, rhsAst) = nodeInfo.node match + case ObjectMethod => + val keyName = + if hasKey(nodeInfo.json("key"), "name") then + nodeInfo.json("key")("name").str + else code(nodeInfo.json("key")) + val keyNode = createFieldIdentifierNode( + keyName, + nodeInfo.lineNumber, + nodeInfo.columnNumber + ) + ( + keyNode, + astForFunctionDeclaration( + nodeInfo, + shouldCreateFunctionReference = true ) - case _ => None // skip - } - - val tmpArrayReturnNode = identifierNode(arrExpr, tmpName) - - scope.popScope() - localAstParentStack.pop() - - val blockChildrenAsts = - assignmentTmpArrayCallNode +: elementAsts :+ Ast(tmpArrayReturnNode) - setArgumentIndices(blockChildrenAsts) - blockAst(blockNode, blockChildrenAsts) - end if - end astForArrayExpression - - def astForTemplateExpression(templateExpr: BabelNodeInfo): Ast = - val argumentAst = astForNodeWithFunctionReference(templateExpr.json("quasi")) - val callName = code(templateExpr.json("tag")) - val callCode = s"$callName(${codeOf(argumentAst.nodes.head)})" - val templateExprCall = - callNode(templateExpr, callCode, callName, DispatchTypes.STATIC_DISPATCH) - val argAsts = List(argumentAst) - callAst(templateExprCall, argAsts) - - protected def astForObjectExpression(objExpr: BabelNodeInfo): Ast = - val blockNode = createBlockNode(objExpr) - - scope.pushNewBlockScope(blockNode) - localAstParentStack.push(blockNode) - - val tmpName = generateUnusedVariableName(usedVariableNames, "_tmp") - val localNode = newLocalNode(tmpName, Defines.Any).order(0) - diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) - - val propertiesAsts = objExpr.json("properties").arr.toList.map { property => - val nodeInfo = createBabelNodeInfo(property) - nodeInfo.node match - case SpreadElement | RestElement => - val arg1Ast = Ast(identifierNode(nodeInfo, tmpName)) - astForSpreadOrRestElement(nodeInfo, Option(arg1Ast)) - case _ => - val (lhsNode, rhsAst) = nodeInfo.node match - case ObjectMethod => - val keyName = - if hasKey(nodeInfo.json("key"), "name") then - nodeInfo.json("key")("name").str - else code(nodeInfo.json("key")) - val keyNode = createFieldIdentifierNode( - keyName, - nodeInfo.lineNumber, - nodeInfo.columnNumber - ) - ( - keyNode, - astForFunctionDeclaration( - nodeInfo, - shouldCreateFunctionReference = true - ) - ) - case ObjectProperty => - val key = createBabelNodeInfo(nodeInfo.json("key")) - val keyName = key.node match - case Identifier if nodeInfo.json("computed").bool => - key.code - case _ if nodeInfo.json("computed").bool => - generateUnusedVariableName( - usedVariableNames, - "_computed_object_property" - ) - case _ => key.code - val keyNode = createFieldIdentifierNode( - keyName, - nodeInfo.lineNumber, - nodeInfo.columnNumber - ) - val ast = astForNodeWithFunctionReference(nodeInfo.json("value")) - (keyNode, ast) - case _ => - // can't happen as per https://github.com/babel/babel/blob/main/packages/babel-types/src/ast-types/generated/index.ts#L573 - // just to make the compiler happy here. - ??? - - val leftHandSideTmpNode = identifierNode(nodeInfo, tmpName) - val leftHandSideFieldAccessAst = - createFieldAccessCallAst( - leftHandSideTmpNode, - lhsNode, - nodeInfo.lineNumber, - nodeInfo.columnNumber + ) + case ObjectProperty => + val key = createBabelNodeInfo(nodeInfo.json("key")) + val keyName = key.node match + case Identifier if nodeInfo.json("computed").bool => + key.code + case _ if nodeInfo.json("computed").bool => + generateUnusedVariableName( + usedVariableNames, + "_computed_object_property" ) + case _ => key.code + val keyNode = createFieldIdentifierNode( + keyName, + nodeInfo.lineNumber, + nodeInfo.columnNumber + ) + val ast = astForNodeWithFunctionReference(nodeInfo.json("value")) + (keyNode, ast) + case _ => + // can't happen as per https://github.com/babel/babel/blob/main/packages/babel-types/src/ast-types/generated/index.ts#L573 + // just to make the compiler happy here. + ??? + + val leftHandSideTmpNode = identifierNode(nodeInfo, tmpName) + val leftHandSideFieldAccessAst = + createFieldAccessCallAst( + leftHandSideTmpNode, + lhsNode, + nodeInfo.lineNumber, + nodeInfo.columnNumber + ) - createAssignmentCallAst( - leftHandSideFieldAccessAst, - rhsAst, - s"$tmpName.${lhsNode.canonicalName} = ${codeOf(rhsAst.nodes.head)}", - nodeInfo.lineNumber, - nodeInfo.columnNumber - ) - end match - } + createAssignmentCallAst( + leftHandSideFieldAccessAst, + rhsAst, + s"$tmpName.${lhsNode.canonicalName} = ${codeOf(rhsAst.nodes.head)}", + nodeInfo.lineNumber, + nodeInfo.columnNumber + ) + end match + } - val tmpNode = identifierNode(objExpr, tmpName) + val tmpNode = identifierNode(objExpr, tmpName) - scope.popScope() - localAstParentStack.pop() + scope.popScope() + localAstParentStack.pop() - val childrenAsts = propertiesAsts :+ Ast(tmpNode) - setArgumentIndices(childrenAsts) - blockAst(blockNode, childrenAsts) - end astForObjectExpression + val childrenAsts = propertiesAsts :+ Ast(tmpNode) + setArgumentIndices(childrenAsts) + blockAst(blockNode, childrenAsts) + end astForObjectExpression - protected def astForTSSatisfiesExpression(satisfiesExpr: BabelNodeInfo): Ast = - // Ignores the type, i.e. `x satisfies T` is understood as `x`. - astForNode(satisfiesExpr.json("expression")) + protected def astForTSSatisfiesExpression(satisfiesExpr: BabelNodeInfo): Ast = + // Ignores the type, i.e. `x satisfies T` is understood as `x`. + astForNode(satisfiesExpr.json("expression")) end AstForExpressionsCreator diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForFunctionsCreator.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForFunctionsCreator.scala index ee435004..c399269a 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForFunctionsCreator.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForFunctionsCreator.scala @@ -19,548 +19,548 @@ import scala.collection.mutable import scala.util.Try trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode): - this: AstCreator => - - case class MethodAst(ast: Ast, methodNode: NewMethod, methodAst: Ast) - - private def handleRestInParameters( - elementNodeInfo: BabelNodeInfo, - paramNodeInfo: BabelNodeInfo, - paramName: String - ): Ast = - val ast = astForNodeWithFunctionReferenceAndCall(elementNodeInfo.json("argument")) - val defaultName = codeForNodes(ast.nodes.toSeq) - val restName = nameForBabelNodeInfo(paramNodeInfo, defaultName) - ast.root match - case Some(_: NewIdentifier) => - val keyNode = createFieldIdentifierNode( - restName, - elementNodeInfo.lineNumber, - elementNodeInfo.columnNumber + this: AstCreator => + + case class MethodAst(ast: Ast, methodNode: NewMethod, methodAst: Ast) + + private def handleRestInParameters( + elementNodeInfo: BabelNodeInfo, + paramNodeInfo: BabelNodeInfo, + paramName: String + ): Ast = + val ast = astForNodeWithFunctionReferenceAndCall(elementNodeInfo.json("argument")) + val defaultName = codeForNodes(ast.nodes.toSeq) + val restName = nameForBabelNodeInfo(paramNodeInfo, defaultName) + ast.root match + case Some(_: NewIdentifier) => + val keyNode = createFieldIdentifierNode( + restName, + elementNodeInfo.lineNumber, + elementNodeInfo.columnNumber + ) + val paramNode = identifierNode(elementNodeInfo, paramName) + val accessAst = + createFieldAccessCallAst( + paramNode, + keyNode, + elementNodeInfo.lineNumber, + elementNodeInfo.columnNumber + ) + createAssignmentCallAst( + ast, + accessAst, + s"$restName = ${codeOf(accessAst.nodes.head)}", + elementNodeInfo.lineNumber, + elementNodeInfo.columnNumber + ) + case _ => + val localParamNode = identifierNode(elementNodeInfo, restName) + createAssignmentCallAst( + Ast(localParamNode), + ast, + s"$restName = ${codeOf(ast.nodes.head)}", + elementNodeInfo.lineNumber, + elementNodeInfo.columnNumber + ) + end match + end handleRestInParameters + + private def handleParameters( + parameters: Seq[Value], + additionalBlockStatements: mutable.ArrayBuffer[Ast], + createLocals: Boolean = true + ): Seq[NewMethodParameterIn] = + withIndex(parameters) { case (param, index) => + val nodeInfo = createBabelNodeInfo(param) + val paramNode = nodeInfo.node match + case RestElement => + val paramName = nodeInfo.code.replace("...", "") + val tpe = typeFor(nodeInfo) + if createLocals then + val localNode = newLocalNode(paramName, tpe).order(0) + diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) + parameterInNode( + nodeInfo, + paramName, + nodeInfo.code, + index, + true, + EvaluationStrategies.BY_VALUE, + Option(tpe) ) - val paramNode = identifierNode(elementNodeInfo, paramName) - val accessAst = - createFieldAccessCallAst( - paramNode, - keyNode, - elementNodeInfo.lineNumber, - elementNodeInfo.columnNumber - ) - createAssignmentCallAst( - ast, - accessAst, - s"$restName = ${codeOf(accessAst.nodes.head)}", - elementNodeInfo.lineNumber, - elementNodeInfo.columnNumber + case AssignmentPattern => + val lhsElement = nodeInfo.json("left") + val rhsElement = nodeInfo.json("right") + val lhsNodeInfo = createBabelNodeInfo(lhsElement) + lhsNodeInfo.node match + case ObjectPattern | ArrayPattern => + val paramName = + generateUnusedVariableName(usedVariableNames, s"param$index") + val param = parameterInNode( + nodeInfo, + paramName, + nodeInfo.code, + index, + isVariadic = false, + EvaluationStrategies.BY_VALUE + ) + scope.addVariable(paramName, param, MethodScope) + + val rhsAst = astForNodeWithFunctionReference(rhsElement) + additionalBlockStatements.addOne( + astForDeconstruction( + lhsNodeInfo, + rhsAst, + nodeInfo.code, + Option(paramName) + ) + ) + param + case _ => + additionalBlockStatements.addOne(convertParamWithDefault(nodeInfo)) + val tpe = typeFor(lhsNodeInfo) + parameterInNode( + lhsNodeInfo, + lhsNodeInfo.code, + nodeInfo.code, + index, + false, + EvaluationStrategies.BY_VALUE, + Option(tpe) + ) + end match + case ArrayPattern => + val paramName = generateUnusedVariableName(usedVariableNames, s"param$index") + val tpe = typeFor(nodeInfo) + val param = parameterInNode( + nodeInfo, + paramName, + nodeInfo.code, + index, + isVariadic = false, + EvaluationStrategies.BY_VALUE, + Option(tpe) ) - case _ => - val localParamNode = identifierNode(elementNodeInfo, restName) - createAssignmentCallAst( - Ast(localParamNode), - ast, - s"$restName = ${codeOf(ast.nodes.head)}", - elementNodeInfo.lineNumber, - elementNodeInfo.columnNumber + additionalBlockStatements.addAll(nodeInfo.json("elements").arr.toList.map { + case element if !element.isNull => + val elementNodeInfo = createBabelNodeInfo(element) + elementNodeInfo.node match + case Identifier => + val elemName = code(elementNodeInfo.json) + val tpe = typeFor(elementNodeInfo) + val localParamNode = identifierNode(elementNodeInfo, elemName) + localParamNode.typeFullName = tpe + + val localNode = newLocalNode(elemName, tpe).order(0) + diffGraph.addEdge( + localAstParentStack.head, + localNode, + EdgeTypes.AST + ) + scope.addVariable(elemName, localNode, MethodScope) + + val paramNode = identifierNode(elementNodeInfo, paramName) + scope.addVariableReference(paramName, paramNode) + + val keyNode = + createFieldIdentifierNode( + elemName, + elementNodeInfo.lineNumber, + elementNodeInfo.columnNumber + ) + val accessAst = + createFieldAccessCallAst( + paramNode, + keyNode, + elementNodeInfo.lineNumber, + elementNodeInfo.columnNumber + ) + createAssignmentCallAst( + Ast(localParamNode), + accessAst, + s"$elemName = ${codeOf(accessAst.nodes.head)}", + elementNodeInfo.lineNumber, + elementNodeInfo.columnNumber + ) + case RestElement => + handleRestInParameters(elementNodeInfo, nodeInfo, paramName) + case _ => astForNodeWithFunctionReference(elementNodeInfo.json) + end match + case _ => Ast() + }) + param + case ObjectPattern => + val paramName = generateUnusedVariableName(usedVariableNames, s"param$index") + // Handle de-structured parameters declared as `{ username: string; password: string; }` + val typeDecl = astForTypeAlias(nodeInfo) + val tpe = typeDecl.root.collect { case t: NewTypeDecl => t.fullName }.getOrElse( + typeFor(nodeInfo) + ) + val param = parameterInNode( + nodeInfo, + paramName, + nodeInfo.code, + index, + isVariadic = false, + EvaluationStrategies.BY_VALUE, + Option(tpe) ) - end match - end handleRestInParameters - - private def handleParameters( - parameters: Seq[Value], - additionalBlockStatements: mutable.ArrayBuffer[Ast], - createLocals: Boolean = true - ): Seq[NewMethodParameterIn] = - withIndex(parameters) { case (param, index) => - val nodeInfo = createBabelNodeInfo(param) - val paramNode = nodeInfo.node match - case RestElement => - val paramName = nodeInfo.code.replace("...", "") - val tpe = typeFor(nodeInfo) - if createLocals then - val localNode = newLocalNode(paramName, tpe).order(0) - diffGraph.addEdge(localAstParentStack.head, localNode, EdgeTypes.AST) + Ast.storeInDiffGraph(typeDecl, diffGraph) + scope.addVariable(paramName, param, MethodScope) + + additionalBlockStatements.addAll(nodeInfo.json("properties").arr.toList.map { + element => + val elementNodeInfo = createBabelNodeInfo(element) + elementNodeInfo.node match + case ObjectProperty => + val elemName = code(elementNodeInfo.json("key")) + val tpe = typeFor(elementNodeInfo) + val localParamNode = identifierNode(elementNodeInfo, elemName) + localParamNode.typeFullName = tpe + + val localNode = newLocalNode(elemName, tpe).order(0) + diffGraph.addEdge( + localAstParentStack.head, + localNode, + EdgeTypes.AST + ) + scope.addVariable(elemName, localNode, MethodScope) + + val paramNode = identifierNode(elementNodeInfo, paramName) + scope.addVariableReference(paramName, paramNode) + + val keyNode = + createFieldIdentifierNode( + elemName, + elementNodeInfo.lineNumber, + elementNodeInfo.columnNumber + ) + val accessAst = + createFieldAccessCallAst( + paramNode, + keyNode, + elementNodeInfo.lineNumber, + elementNodeInfo.columnNumber + ) + val assignmentCallAst = createAssignmentCallAst( + Ast(localParamNode), + accessAst, + s"$elemName = ${codeOf(accessAst.nodes.head)}", + elementNodeInfo.lineNumber, + elementNodeInfo.columnNumber + ) + // Handle identifiers referring to locals created by destructured parameters + assignmentCallAst.nodes + .collect { + case i: NewIdentifier if localNode.name == i.name => i + } + .map { i => assignmentCallAst.withRefEdge(i, localNode) } + .reduce(_.merge(_)) + case RestElement => + handleRestInParameters(elementNodeInfo, nodeInfo, paramName) + case _ => astForNodeWithFunctionReference(elementNodeInfo.json) + end match + }) + param + case Identifier => + // Handle types declared as `credentials: { username: string; password: string; }` + val tpe = + Try(createBabelNodeInfo(nodeInfo.json("typeAnnotation")("typeAnnotation"))) + .map(x => + x.node match + case TSTypeLiteral => + val typeDecl = astForTypeAlias(x) + Ast.storeInDiffGraph(typeDecl, diffGraph) + typeDecl.root.collect { case t: NewTypeDecl => + t.fullName + }.getOrElse(typeFor(nodeInfo)) + case _ => typeFor(nodeInfo) + ) + .getOrElse(typeFor(nodeInfo)) + + val name = nodeInfo.json("name").str + val node = parameterInNode( nodeInfo, - paramName, + name, nodeInfo.code, index, - true, + false, EvaluationStrategies.BY_VALUE, Option(tpe) ) - case AssignmentPattern => - val lhsElement = nodeInfo.json("left") - val rhsElement = nodeInfo.json("right") - val lhsNodeInfo = createBabelNodeInfo(lhsElement) - lhsNodeInfo.node match - case ObjectPattern | ArrayPattern => - val paramName = - generateUnusedVariableName(usedVariableNames, s"param$index") - val param = parameterInNode( - nodeInfo, - paramName, - nodeInfo.code, - index, - isVariadic = false, - EvaluationStrategies.BY_VALUE - ) - scope.addVariable(paramName, param, MethodScope) - - val rhsAst = astForNodeWithFunctionReference(rhsElement) - additionalBlockStatements.addOne( - astForDeconstruction( - lhsNodeInfo, - rhsAst, - nodeInfo.code, - Option(paramName) - ) - ) - param - case _ => - additionalBlockStatements.addOne(convertParamWithDefault(nodeInfo)) - val tpe = typeFor(lhsNodeInfo) - parameterInNode( - lhsNodeInfo, - lhsNodeInfo.code, - nodeInfo.code, - index, - false, - EvaluationStrategies.BY_VALUE, - Option(tpe) - ) - end match - case ArrayPattern => - val paramName = generateUnusedVariableName(usedVariableNames, s"param$index") - val tpe = typeFor(nodeInfo) - val param = parameterInNode( + scope.addVariable(name, node, MethodScope) + node + case TSParameterProperty => + val unpackedParam = createBabelNodeInfo(nodeInfo.json("parameter")) + val tpe = typeFor(unpackedParam) + + val name = unpackedParam.node match + case AssignmentPattern => + createBabelNodeInfo(unpackedParam.json("left")).code + case _ => unpackedParam.json("name").str + val node = + parameterInNode( nodeInfo, - paramName, + name, nodeInfo.code, index, - isVariadic = false, + false, EvaluationStrategies.BY_VALUE, Option(tpe) ) - additionalBlockStatements.addAll(nodeInfo.json("elements").arr.toList.map { - case element if !element.isNull => - val elementNodeInfo = createBabelNodeInfo(element) - elementNodeInfo.node match - case Identifier => - val elemName = code(elementNodeInfo.json) - val tpe = typeFor(elementNodeInfo) - val localParamNode = identifierNode(elementNodeInfo, elemName) - localParamNode.typeFullName = tpe - - val localNode = newLocalNode(elemName, tpe).order(0) - diffGraph.addEdge( - localAstParentStack.head, - localNode, - EdgeTypes.AST - ) - scope.addVariable(elemName, localNode, MethodScope) - - val paramNode = identifierNode(elementNodeInfo, paramName) - scope.addVariableReference(paramName, paramNode) - - val keyNode = - createFieldIdentifierNode( - elemName, - elementNodeInfo.lineNumber, - elementNodeInfo.columnNumber - ) - val accessAst = - createFieldAccessCallAst( - paramNode, - keyNode, - elementNodeInfo.lineNumber, - elementNodeInfo.columnNumber - ) - createAssignmentCallAst( - Ast(localParamNode), - accessAst, - s"$elemName = ${codeOf(accessAst.nodes.head)}", - elementNodeInfo.lineNumber, - elementNodeInfo.columnNumber - ) - case RestElement => - handleRestInParameters(elementNodeInfo, nodeInfo, paramName) - case _ => astForNodeWithFunctionReference(elementNodeInfo.json) - end match - case _ => Ast() - }) - param - case ObjectPattern => - val paramName = generateUnusedVariableName(usedVariableNames, s"param$index") - // Handle de-structured parameters declared as `{ username: string; password: string; }` - val typeDecl = astForTypeAlias(nodeInfo) - val tpe = typeDecl.root.collect { case t: NewTypeDecl => t.fullName }.getOrElse( - typeFor(nodeInfo) - ) - val param = parameterInNode( + scope.addVariable(name, node, MethodScope) + node + case _ => + val tpe = typeFor(nodeInfo) + val node = + parameterInNode( nodeInfo, - paramName, + nodeInfo.code, nodeInfo.code, index, isVariadic = false, EvaluationStrategies.BY_VALUE, Option(tpe) ) - Ast.storeInDiffGraph(typeDecl, diffGraph) - scope.addVariable(paramName, param, MethodScope) - - additionalBlockStatements.addAll(nodeInfo.json("properties").arr.toList.map { - element => - val elementNodeInfo = createBabelNodeInfo(element) - elementNodeInfo.node match - case ObjectProperty => - val elemName = code(elementNodeInfo.json("key")) - val tpe = typeFor(elementNodeInfo) - val localParamNode = identifierNode(elementNodeInfo, elemName) - localParamNode.typeFullName = tpe - - val localNode = newLocalNode(elemName, tpe).order(0) - diffGraph.addEdge( - localAstParentStack.head, - localNode, - EdgeTypes.AST - ) - scope.addVariable(elemName, localNode, MethodScope) - - val paramNode = identifierNode(elementNodeInfo, paramName) - scope.addVariableReference(paramName, paramNode) - - val keyNode = - createFieldIdentifierNode( - elemName, - elementNodeInfo.lineNumber, - elementNodeInfo.columnNumber - ) - val accessAst = - createFieldAccessCallAst( - paramNode, - keyNode, - elementNodeInfo.lineNumber, - elementNodeInfo.columnNumber - ) - val assignmentCallAst = createAssignmentCallAst( - Ast(localParamNode), - accessAst, - s"$elemName = ${codeOf(accessAst.nodes.head)}", - elementNodeInfo.lineNumber, - elementNodeInfo.columnNumber - ) - // Handle identifiers referring to locals created by destructured parameters - assignmentCallAst.nodes - .collect { - case i: NewIdentifier if localNode.name == i.name => i - } - .map { i => assignmentCallAst.withRefEdge(i, localNode) } - .reduce(_.merge(_)) - case RestElement => - handleRestInParameters(elementNodeInfo, nodeInfo, paramName) - case _ => astForNodeWithFunctionReference(elementNodeInfo.json) - end match - }) - param - case Identifier => - // Handle types declared as `credentials: { username: string; password: string; }` - val tpe = - Try(createBabelNodeInfo(nodeInfo.json("typeAnnotation")("typeAnnotation"))) - .map(x => - x.node match - case TSTypeLiteral => - val typeDecl = astForTypeAlias(x) - Ast.storeInDiffGraph(typeDecl, diffGraph) - typeDecl.root.collect { case t: NewTypeDecl => - t.fullName - }.getOrElse(typeFor(nodeInfo)) - case _ => typeFor(nodeInfo) - ) - .getOrElse(typeFor(nodeInfo)) - - val name = nodeInfo.json("name").str - val node = - parameterInNode( - nodeInfo, - name, - nodeInfo.code, - index, - false, - EvaluationStrategies.BY_VALUE, - Option(tpe) - ) - scope.addVariable(name, node, MethodScope) - node - case TSParameterProperty => - val unpackedParam = createBabelNodeInfo(nodeInfo.json("parameter")) - val tpe = typeFor(unpackedParam) - - val name = unpackedParam.node match - case AssignmentPattern => - createBabelNodeInfo(unpackedParam.json("left")).code - case _ => unpackedParam.json("name").str - val node = - parameterInNode( - nodeInfo, - name, - nodeInfo.code, - index, - false, - EvaluationStrategies.BY_VALUE, - Option(tpe) - ) - scope.addVariable(name, node, MethodScope) - node - case _ => - val tpe = typeFor(nodeInfo) - val node = - parameterInNode( - nodeInfo, - nodeInfo.code, - nodeInfo.code, - index, - isVariadic = false, - EvaluationStrategies.BY_VALUE, - Option(tpe) - ) - scope.addVariable(nodeInfo.code, node, MethodScope) - node - val decoratorAsts = astsForDecorators(nodeInfo) - decoratorAsts.foreach { decoratorAst => - Ast.storeInDiffGraph(decoratorAst, diffGraph) - decoratorAst.root.foreach(diffGraph.addEdge(paramNode, _, EdgeTypes.AST)) - } - paramNode - } - - private def convertParamWithDefault(element: BabelNodeInfo): Ast = - val lhsElement = element.json("left") - val rhsElement = element.json("right") - - val rhsAst = astForNodeWithFunctionReference(rhsElement) - - val lhsAst = astForNode(lhsElement) - - val testAst = - val keyNode = identifierNode(element, codeOf(lhsAst.nodes.head)) - val voidCallNode = - callNode(element, "void 0", ".void", DispatchTypes.STATIC_DISPATCH) - val equalsCallAst = createEqualsCallAst( - Ast(keyNode), - Ast(voidCallNode), - element.lineNumber, - element.columnNumber - ) - equalsCallAst - val falseNode = identifierNode(element, codeOf(lhsAst.nodes.head)) - val ternaryNodeAst = - createTernaryCallAst( - testAst, - rhsAst, - Ast(falseNode), - element.lineNumber, - element.columnNumber - ) - createAssignmentCallAst( - lhsAst, - ternaryNodeAst, - s"${codeOf(lhsAst.nodes.head)} = ${codeOf(ternaryNodeAst.nodes.head)}", + scope.addVariable(nodeInfo.code, node, MethodScope) + node + val decoratorAsts = astsForDecorators(nodeInfo) + decoratorAsts.foreach { decoratorAst => + Ast.storeInDiffGraph(decoratorAst, diffGraph) + decoratorAst.root.foreach(diffGraph.addEdge(paramNode, _, EdgeTypes.AST)) + } + paramNode + } + + private def convertParamWithDefault(element: BabelNodeInfo): Ast = + val lhsElement = element.json("left") + val rhsElement = element.json("right") + + val rhsAst = astForNodeWithFunctionReference(rhsElement) + + val lhsAst = astForNode(lhsElement) + + val testAst = + val keyNode = identifierNode(element, codeOf(lhsAst.nodes.head)) + val voidCallNode = + callNode(element, "void 0", ".void", DispatchTypes.STATIC_DISPATCH) + val equalsCallAst = createEqualsCallAst( + Ast(keyNode), + Ast(voidCallNode), + element.lineNumber, + element.columnNumber + ) + equalsCallAst + val falseNode = identifierNode(element, codeOf(lhsAst.nodes.head)) + val ternaryNodeAst = + createTernaryCallAst( + testAst, + rhsAst, + Ast(falseNode), element.lineNumber, element.columnNumber ) - end convertParamWithDefault - - private def getParentTypeDecl: NewTypeDecl = - methodAstParentStack.collectFirst { case n: NewTypeDecl => n }.getOrElse(rootTypeDecl.head) - - protected def astForTSDeclareFunction(func: BabelNodeInfo): Ast = - val functionNode = createMethodDefinitionNode(func) - val bindingNode = newBindingNode("", "", "") - diffGraph.addEdge(getParentTypeDecl, bindingNode, EdgeTypes.BINDS) - diffGraph.addEdge(bindingNode, functionNode, EdgeTypes.REF) - addModifier(functionNode, func.json) - Ast(functionNode) - - protected def createMethodDefinitionNode( - func: BabelNodeInfo, - methodBlockContent: List[Ast] = List.empty - ): NewMethod = - val (methodName, methodFullName) = calcMethodNameAndFullName(func) - val methodNode_ = - methodNode(func, methodName, func.code, methodFullName, None, parserResult.filename) - val virtualModifierNode = NewModifier().modifierType(ModifierTypes.VIRTUAL) - methodAstParentStack.push(methodNode_) - - val thisNode = - parameterInNode(func, "this", "this", 0, false, EvaluationStrategies.BY_VALUE) - .dynamicTypeHintFullName(typeHintForThisExpression()) - scope.addVariable("this", thisNode, MethodScope) - - val paramNodes = if hasKey(func.json, "parameters") then - handleParameters( - func.json("parameters").arr.toSeq, - mutable.ArrayBuffer.empty[Ast], - createLocals = false - ) - else - handleParameters( - func.json("params").arr.toSeq, - mutable.ArrayBuffer.empty[Ast], - createLocals = false - ) - - val methodReturnNode = createMethodReturnNode(func) - - methodAstParentStack.pop() - - val functionTypeAndTypeDeclAst = - createFunctionTypeAndTypeDeclAst( - func, - methodNode_, - methodAstParentStack.head, - methodName, - methodFullName, - parserResult.filename - ) - - val mAst = if methodBlockContent.isEmpty then - methodStubAst( - methodNode_, - thisNode +: paramNodes, - methodReturnNode, - List(virtualModifierNode) - ) - else - setArgumentIndices(methodBlockContent) - val bodyAst = blockAst(NewBlock(), methodBlockContent) - methodAstWithAnnotations( - methodNode_, - (thisNode +: paramNodes).map(Ast(_)), - bodyAst, - methodReturnNode, - annotations = astsForDecorators(func) - ) - - Ast.storeInDiffGraph(mAst, diffGraph) - Ast.storeInDiffGraph(functionTypeAndTypeDeclAst, diffGraph) - diffGraph.addEdge(methodAstParentStack.head, methodNode_, EdgeTypes.AST) - - methodNode_ - end createMethodDefinitionNode - - protected def createMethodAstAndNode( - func: BabelNodeInfo, - shouldCreateFunctionReference: Boolean = false, - shouldCreateAssignmentCall: Boolean = false, - methodBlockContent: List[Ast] = List.empty - ): MethodAst = - val (methodName, methodFullName) = calcMethodNameAndFullName(func) - val methodRefNode_ = if !shouldCreateFunctionReference then - None - else Option(methodRefNode(func, methodName, methodFullName, methodFullName)) - - val callAst = if shouldCreateAssignmentCall && shouldCreateFunctionReference then - val idNode = identifierNode(func, methodName) - val idLocal = newLocalNode(methodName, methodFullName).order(0) - diffGraph.addEdge(localAstParentStack.head, idLocal, EdgeTypes.AST) - scope.addVariable(methodName, idLocal, BlockScope) - scope.addVariableReference(methodName, idNode) - val code = s"function $methodName = ${func.code}" - val assignment = createAssignmentCallAst( - idNode, - methodRefNode_.get, - code, - func.lineNumber, - func.columnNumber - ) - assignment + createAssignmentCallAst( + lhsAst, + ternaryNodeAst, + s"${codeOf(lhsAst.nodes.head)} = ${codeOf(ternaryNodeAst.nodes.head)}", + element.lineNumber, + element.columnNumber + ) + end convertParamWithDefault + + private def getParentTypeDecl: NewTypeDecl = + methodAstParentStack.collectFirst { case n: NewTypeDecl => n }.getOrElse(rootTypeDecl.head) + + protected def astForTSDeclareFunction(func: BabelNodeInfo): Ast = + val functionNode = createMethodDefinitionNode(func) + val bindingNode = newBindingNode("", "", "") + diffGraph.addEdge(getParentTypeDecl, bindingNode, EdgeTypes.BINDS) + diffGraph.addEdge(bindingNode, functionNode, EdgeTypes.REF) + addModifier(functionNode, func.json) + Ast(functionNode) + + protected def createMethodDefinitionNode( + func: BabelNodeInfo, + methodBlockContent: List[Ast] = List.empty + ): NewMethod = + val (methodName, methodFullName) = calcMethodNameAndFullName(func) + val methodNode_ = + methodNode(func, methodName, func.code, methodFullName, None, parserResult.filename) + val virtualModifierNode = NewModifier().modifierType(ModifierTypes.VIRTUAL) + methodAstParentStack.push(methodNode_) + + val thisNode = + parameterInNode(func, "this", "this", 0, false, EvaluationStrategies.BY_VALUE) + .dynamicTypeHintFullName(typeHintForThisExpression()) + scope.addVariable("this", thisNode, MethodScope) + + val paramNodes = if hasKey(func.json, "parameters") then + handleParameters( + func.json("parameters").arr.toSeq, + mutable.ArrayBuffer.empty[Ast], + createLocals = false + ) + else + handleParameters( + func.json("params").arr.toSeq, + mutable.ArrayBuffer.empty[Ast], + createLocals = false + ) + + val methodReturnNode = createMethodReturnNode(func) + + methodAstParentStack.pop() + + val functionTypeAndTypeDeclAst = + createFunctionTypeAndTypeDeclAst( + func, + methodNode_, + methodAstParentStack.head, + methodName, + methodFullName, + parserResult.filename + ) + + val mAst = if methodBlockContent.isEmpty then + methodStubAst( + methodNode_, + thisNode +: paramNodes, + methodReturnNode, + List(virtualModifierNode) + ) + else + setArgumentIndices(methodBlockContent) + val bodyAst = blockAst(NewBlock(), methodBlockContent) + methodAstWithAnnotations( + methodNode_, + (thisNode +: paramNodes).map(Ast(_)), + bodyAst, + methodReturnNode, + annotations = astsForDecorators(func) + ) + + Ast.storeInDiffGraph(mAst, diffGraph) + Ast.storeInDiffGraph(functionTypeAndTypeDeclAst, diffGraph) + diffGraph.addEdge(methodAstParentStack.head, methodNode_, EdgeTypes.AST) + + methodNode_ + end createMethodDefinitionNode + + protected def createMethodAstAndNode( + func: BabelNodeInfo, + shouldCreateFunctionReference: Boolean = false, + shouldCreateAssignmentCall: Boolean = false, + methodBlockContent: List[Ast] = List.empty + ): MethodAst = + val (methodName, methodFullName) = calcMethodNameAndFullName(func) + val methodRefNode_ = if !shouldCreateFunctionReference then + None + else Option(methodRefNode(func, methodName, methodFullName, methodFullName)) + + val callAst = if shouldCreateAssignmentCall && shouldCreateFunctionReference then + val idNode = identifierNode(func, methodName) + val idLocal = newLocalNode(methodName, methodFullName).order(0) + diffGraph.addEdge(localAstParentStack.head, idLocal, EdgeTypes.AST) + scope.addVariable(methodName, idLocal, BlockScope) + scope.addVariableReference(methodName, idNode) + val code = s"function $methodName = ${func.code}" + val assignment = createAssignmentCallAst( + idNode, + methodRefNode_.get, + code, + func.lineNumber, + func.columnNumber + ) + assignment + else + Ast() + + val methodNode_ = + methodNode(func, methodName, func.code, methodFullName, None, parserResult.filename) + val virtualModifierNode = NewModifier().modifierType(ModifierTypes.VIRTUAL) + + methodAstParentStack.push(methodNode_) + + val bodyJson = func.json("body") + val bodyNodeInfo = createBabelNodeInfo(bodyJson) + val blockNode = createBlockNode(bodyNodeInfo) + val additionalBlockStatements = mutable.ArrayBuffer.empty[Ast] + + val capturingRefNode = + if shouldCreateFunctionReference then + methodRefNode_ else - Ast() - - val methodNode_ = - methodNode(func, methodName, func.code, methodFullName, None, parserResult.filename) - val virtualModifierNode = NewModifier().modifierType(ModifierTypes.VIRTUAL) - - methodAstParentStack.push(methodNode_) - - val bodyJson = func.json("body") - val bodyNodeInfo = createBabelNodeInfo(bodyJson) - val blockNode = createBlockNode(bodyNodeInfo) - val additionalBlockStatements = mutable.ArrayBuffer.empty[Ast] - - val capturingRefNode = - if shouldCreateFunctionReference then - methodRefNode_ - else - typeRefIdStack.headOption - scope.pushNewMethodScope(methodFullName, methodName, blockNode, capturingRefNode) - localAstParentStack.push(blockNode) - - val thisNode = - parameterInNode(func, "this", "this", 0, false, EvaluationStrategies.BY_VALUE) - .dynamicTypeHintFullName(typeHintForThisExpression()) - scope.addVariable("this", thisNode, MethodScope) - - val paramNodes = handleParameters(func.json("params").arr.toSeq, additionalBlockStatements) - - val bodyStmtAsts = func.node match - case ArrowFunctionExpression => - bodyNodeInfo.node match - case BlockStatement => - // when body contains more than one statement, use bodyJson("body")) to avoid double Block node - createBlockStatementAsts(bodyJson("body")) - case _ => - // when body is just one expression like const foo = () => 42, generate a Return node - val retCode = bodyNodeInfo.code.stripSuffix(";") - returnAst( - returnNode(bodyNodeInfo, retCode), - List(astForNodeWithFunctionReference(bodyJson)) - ) :: Nil - case _ => createBlockStatementAsts(bodyJson("body")) - val methodBlockChildren = - methodBlockContent ++ additionalBlockStatements.toList ++ bodyStmtAsts - setArgumentIndices(methodBlockChildren) - - val methodReturnNode = createMethodReturnNode(func) - - localAstParentStack.pop() - scope.popScope() - methodAstParentStack.pop() - - val functionTypeAndTypeDeclAst = - createFunctionTypeAndTypeDeclAst( - func, - methodNode_, - methodAstParentStack.head, - methodName, - methodFullName, - parserResult.filename - ) - - val mAst = - methodAstWithAnnotations( - methodNode_, - (thisNode +: paramNodes).map(Ast(_)), - blockAst(blockNode, methodBlockChildren), - methodReturnNode, - List(virtualModifierNode), - astsForDecorators(func) - ) - Ast.storeInDiffGraph(mAst, diffGraph) - Ast.storeInDiffGraph(functionTypeAndTypeDeclAst, diffGraph) - diffGraph.addEdge(methodAstParentStack.head, methodNode_, EdgeTypes.AST) - - methodRefNode_ match - case Some(ref) if callAst.nodes.isEmpty => - MethodAst(Ast(ref), methodNode_, mAst) + typeRefIdStack.headOption + scope.pushNewMethodScope(methodFullName, methodName, blockNode, capturingRefNode) + localAstParentStack.push(blockNode) + + val thisNode = + parameterInNode(func, "this", "this", 0, false, EvaluationStrategies.BY_VALUE) + .dynamicTypeHintFullName(typeHintForThisExpression()) + scope.addVariable("this", thisNode, MethodScope) + + val paramNodes = handleParameters(func.json("params").arr.toSeq, additionalBlockStatements) + + val bodyStmtAsts = func.node match + case ArrowFunctionExpression => + bodyNodeInfo.node match + case BlockStatement => + // when body contains more than one statement, use bodyJson("body")) to avoid double Block node + createBlockStatementAsts(bodyJson("body")) case _ => - MethodAst(callAst, methodNode_, mAst) - end createMethodAstAndNode - - protected def astForFunctionDeclaration( - func: BabelNodeInfo, - shouldCreateFunctionReference: Boolean = false, - shouldCreateAssignmentCall: Boolean = false - ): Ast = - createMethodAstAndNode(func, shouldCreateFunctionReference, shouldCreateAssignmentCall).ast + // when body is just one expression like const foo = () => 42, generate a Return node + val retCode = bodyNodeInfo.code.stripSuffix(";") + returnAst( + returnNode(bodyNodeInfo, retCode), + List(astForNodeWithFunctionReference(bodyJson)) + ) :: Nil + case _ => createBlockStatementAsts(bodyJson("body")) + val methodBlockChildren = + methodBlockContent ++ additionalBlockStatements.toList ++ bodyStmtAsts + setArgumentIndices(methodBlockChildren) + + val methodReturnNode = createMethodReturnNode(func) + + localAstParentStack.pop() + scope.popScope() + methodAstParentStack.pop() + + val functionTypeAndTypeDeclAst = + createFunctionTypeAndTypeDeclAst( + func, + methodNode_, + methodAstParentStack.head, + methodName, + methodFullName, + parserResult.filename + ) + + val mAst = + methodAstWithAnnotations( + methodNode_, + (thisNode +: paramNodes).map(Ast(_)), + blockAst(blockNode, methodBlockChildren), + methodReturnNode, + List(virtualModifierNode), + astsForDecorators(func) + ) + Ast.storeInDiffGraph(mAst, diffGraph) + Ast.storeInDiffGraph(functionTypeAndTypeDeclAst, diffGraph) + diffGraph.addEdge(methodAstParentStack.head, methodNode_, EdgeTypes.AST) + + methodRefNode_ match + case Some(ref) if callAst.nodes.isEmpty => + MethodAst(Ast(ref), methodNode_, mAst) + case _ => + MethodAst(callAst, methodNode_, mAst) + end createMethodAstAndNode + + protected def astForFunctionDeclaration( + func: BabelNodeInfo, + shouldCreateFunctionReference: Boolean = false, + shouldCreateAssignmentCall: Boolean = false + ): Ast = + createMethodAstAndNode(func, shouldCreateFunctionReference, shouldCreateAssignmentCall).ast end AstForFunctionsCreator diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForPrimitivesCreator.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForPrimitivesCreator.scala index 3edb9e8b..81a13739 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForPrimitivesCreator.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForPrimitivesCreator.scala @@ -6,101 +6,101 @@ import io.appthreat.x2cpg.{Ast, ValidationMode} import io.shiftleft.codepropertygraph.generated.{DispatchTypes, Operators} trait AstForPrimitivesCreator(implicit withSchemaValidation: ValidationMode): - this: AstCreator => - - protected def astForIdentifier(ident: BabelNodeInfo, typeFullName: Option[String] = None): Ast = - val name = ident.json("name").str - val identNode = identifierNode(ident, name) - val tpe = typeFullName match - case Some(Defines.Any) => typeFor(ident) - case Some(otherType) => otherType - case None => typeFor(ident) - identNode.typeFullName = tpe - scope.addVariableReference(name, identNode) - Ast(identNode) - - protected def astForSuperKeyword(superKeyword: BabelNodeInfo): Ast = - Ast(identifierNode(superKeyword, "super")) - - protected def astForImportKeyword(importKeyword: BabelNodeInfo): Ast = - Ast(identifierNode(importKeyword, "import")) - - protected def astForNullLiteral(nullLiteral: BabelNodeInfo): Ast = - Ast(literalNode(nullLiteral, nullLiteral.code, Option(Defines.Null))) - - protected def astForStringLiteral(stringLiteral: BabelNodeInfo): Ast = - val code = s"\"${stringLiteral.json("value").str}\"" - Ast(literalNode(stringLiteral, code, Option(Defines.String))) - - protected def astForPrivateName(privateName: BabelNodeInfo): Ast = - astForIdentifier(createBabelNodeInfo(privateName.json("id"))) - - protected def astForSpreadOrRestElement( - spreadElement: BabelNodeInfo, - arg1Ast: Option[Ast] = None - ): Ast = - val ast = astForNodeWithFunctionReference(spreadElement.json("argument")) - val callNode_ = - callNode( - spreadElement, - spreadElement.code, - ".spread", - DispatchTypes.STATIC_DISPATCH - ) - callAst(callNode_, arg1Ast.toList :+ ast) - - protected def astForTemplateElement(templateElement: BabelNodeInfo): Ast = - Ast(literalNode( - templateElement, - s"\"${templateElement.json("value")("raw").str}\"", - Option(Defines.String) - )) - - protected def astForRegExpLiteral(regExpLiteral: BabelNodeInfo): Ast = - Ast(literalNode(regExpLiteral, regExpLiteral.code, Option(Defines.String))) - - protected def astForRegexLiteral(regexLiteral: BabelNodeInfo): Ast = - Ast(literalNode(regexLiteral, regexLiteral.code, Option(Defines.String))) - - protected def astForNumberLiteral(numberLiteral: BabelNodeInfo): Ast = - Ast(literalNode(numberLiteral, numberLiteral.code, Option(Defines.Number))) - - protected def astForNumericLiteral(numericLiteral: BabelNodeInfo): Ast = - Ast(literalNode(numericLiteral, numericLiteral.code, Option(Defines.Number))) - - protected def astForDecimalLiteral(decimalLiteral: BabelNodeInfo): Ast = - Ast(literalNode(decimalLiteral, decimalLiteral.code, Option(Defines.Number))) - - protected def astForBigIntLiteral(bigIntLiteral: BabelNodeInfo): Ast = - Ast(literalNode(bigIntLiteral, bigIntLiteral.code, Option(Defines.Number))) - - protected def astForBooleanLiteral(booleanLiteral: BabelNodeInfo): Ast = - Ast(literalNode(booleanLiteral, booleanLiteral.code, Option(Defines.Boolean))) - - protected def astForTemplateLiteral(templateLiteral: BabelNodeInfo): Ast = - val expressions = templateLiteral.json("expressions").arr.toList - val quasis = templateLiteral.json("quasis").arr.toList.filterNot(_("tail").bool) - val quasisTail = templateLiteral.json("quasis").arr.toList.filter(_("tail").bool).head - - if expressions.isEmpty && quasis.isEmpty then - astForTemplateElement(createBabelNodeInfo(quasisTail)) - else - val callName = Operators.formatString - val argsCodes = expressions.zip(quasis).flatMap { case (expression, quasi) => - List(s"\"${quasi("value")("raw").str}\"", code(expression)) - } - val callCode = - s"$callName${(argsCodes :+ s"\"${quasisTail("value")("raw").str}\"").mkString("(", ", ", ")")}" - val templateCall = - callNode(templateLiteral, callCode, callName, DispatchTypes.STATIC_DISPATCH) - val argumentAsts = expressions.zip(quasis).flatMap { case (expression, quasi) => - List( - astForNodeWithFunctionReference(quasi), - astForNodeWithFunctionReference(expression) - ) - } - val argAsts = argumentAsts :+ astForNodeWithFunctionReference(quasisTail) - callAst(templateCall, argAsts) - end if - end astForTemplateLiteral + this: AstCreator => + + protected def astForIdentifier(ident: BabelNodeInfo, typeFullName: Option[String] = None): Ast = + val name = ident.json("name").str + val identNode = identifierNode(ident, name) + val tpe = typeFullName match + case Some(Defines.Any) => typeFor(ident) + case Some(otherType) => otherType + case None => typeFor(ident) + identNode.typeFullName = tpe + scope.addVariableReference(name, identNode) + Ast(identNode) + + protected def astForSuperKeyword(superKeyword: BabelNodeInfo): Ast = + Ast(identifierNode(superKeyword, "super")) + + protected def astForImportKeyword(importKeyword: BabelNodeInfo): Ast = + Ast(identifierNode(importKeyword, "import")) + + protected def astForNullLiteral(nullLiteral: BabelNodeInfo): Ast = + Ast(literalNode(nullLiteral, nullLiteral.code, Option(Defines.Null))) + + protected def astForStringLiteral(stringLiteral: BabelNodeInfo): Ast = + val code = s"\"${stringLiteral.json("value").str}\"" + Ast(literalNode(stringLiteral, code, Option(Defines.String))) + + protected def astForPrivateName(privateName: BabelNodeInfo): Ast = + astForIdentifier(createBabelNodeInfo(privateName.json("id"))) + + protected def astForSpreadOrRestElement( + spreadElement: BabelNodeInfo, + arg1Ast: Option[Ast] = None + ): Ast = + val ast = astForNodeWithFunctionReference(spreadElement.json("argument")) + val callNode_ = + callNode( + spreadElement, + spreadElement.code, + ".spread", + DispatchTypes.STATIC_DISPATCH + ) + callAst(callNode_, arg1Ast.toList :+ ast) + + protected def astForTemplateElement(templateElement: BabelNodeInfo): Ast = + Ast(literalNode( + templateElement, + s"\"${templateElement.json("value")("raw").str}\"", + Option(Defines.String) + )) + + protected def astForRegExpLiteral(regExpLiteral: BabelNodeInfo): Ast = + Ast(literalNode(regExpLiteral, regExpLiteral.code, Option(Defines.String))) + + protected def astForRegexLiteral(regexLiteral: BabelNodeInfo): Ast = + Ast(literalNode(regexLiteral, regexLiteral.code, Option(Defines.String))) + + protected def astForNumberLiteral(numberLiteral: BabelNodeInfo): Ast = + Ast(literalNode(numberLiteral, numberLiteral.code, Option(Defines.Number))) + + protected def astForNumericLiteral(numericLiteral: BabelNodeInfo): Ast = + Ast(literalNode(numericLiteral, numericLiteral.code, Option(Defines.Number))) + + protected def astForDecimalLiteral(decimalLiteral: BabelNodeInfo): Ast = + Ast(literalNode(decimalLiteral, decimalLiteral.code, Option(Defines.Number))) + + protected def astForBigIntLiteral(bigIntLiteral: BabelNodeInfo): Ast = + Ast(literalNode(bigIntLiteral, bigIntLiteral.code, Option(Defines.Number))) + + protected def astForBooleanLiteral(booleanLiteral: BabelNodeInfo): Ast = + Ast(literalNode(booleanLiteral, booleanLiteral.code, Option(Defines.Boolean))) + + protected def astForTemplateLiteral(templateLiteral: BabelNodeInfo): Ast = + val expressions = templateLiteral.json("expressions").arr.toList + val quasis = templateLiteral.json("quasis").arr.toList.filterNot(_("tail").bool) + val quasisTail = templateLiteral.json("quasis").arr.toList.filter(_("tail").bool).head + + if expressions.isEmpty && quasis.isEmpty then + astForTemplateElement(createBabelNodeInfo(quasisTail)) + else + val callName = Operators.formatString + val argsCodes = expressions.zip(quasis).flatMap { case (expression, quasi) => + List(s"\"${quasi("value")("raw").str}\"", code(expression)) + } + val callCode = + s"$callName${(argsCodes :+ s"\"${quasisTail("value")("raw").str}\"").mkString("(", ", ", ")")}" + val templateCall = + callNode(templateLiteral, callCode, callName, DispatchTypes.STATIC_DISPATCH) + val argumentAsts = expressions.zip(quasis).flatMap { case (expression, quasi) => + List( + astForNodeWithFunctionReference(quasi), + astForNodeWithFunctionReference(expression) + ) + } + val argAsts = argumentAsts :+ astForNodeWithFunctionReference(quasisTail) + callAst(templateCall, argAsts) + end if + end astForTemplateLiteral end AstForPrimitivesCreator diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForStatementsCreator.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForStatementsCreator.scala index 6c403305..f957dc96 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForStatementsCreator.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForStatementsCreator.scala @@ -16,799 +16,1000 @@ import ujson.Obj import ujson.Value trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode): - this: AstCreator => - - /** Sort all block statements with the following result: - * - all function declarations go first - * - all type aliases that are not plain type references go last - * - all remaining type aliases go before that - * - all remaining statements go second - * - * We do this to get TypeDecls created at the right spot so we can make use of them for the - * type aliases. - */ - private def sortBlockStatements(blockStatements: List[BabelNodeInfo]): List[BabelNodeInfo] = - blockStatements.sortBy { nodeInfo => - nodeInfo.node match - case ImportDeclaration => 0 - case FunctionDeclaration => 1 - case DeclareTypeAlias if isPlainTypeAlias(nodeInfo) => 4 - case TypeAlias if isPlainTypeAlias(nodeInfo) => 4 - case TSTypeAliasDeclaration if isPlainTypeAlias(nodeInfo) => 4 - case DeclareTypeAlias => 3 - case TypeAlias => 3 - case TSTypeAliasDeclaration => 3 - case _ => 2 + this: AstCreator => + + /** Sort all block statements with the following result: + * - all function declarations go first + * - all type aliases that are not plain type references go last + * - all remaining type aliases go before that + * - all remaining statements go second + * + * We do this to get TypeDecls created at the right spot so we can make use of them for the type + * aliases. + */ + private def sortBlockStatements(blockStatements: List[BabelNodeInfo]): List[BabelNodeInfo] = + blockStatements.sortBy { nodeInfo => + nodeInfo.node match + case ImportDeclaration => 0 + case FunctionDeclaration => 1 + case DeclareTypeAlias if isPlainTypeAlias(nodeInfo) => 4 + case TypeAlias if isPlainTypeAlias(nodeInfo) => 4 + case TSTypeAliasDeclaration if isPlainTypeAlias(nodeInfo) => 4 + case DeclareTypeAlias => 3 + case TypeAlias => 3 + case TSTypeAliasDeclaration => 3 + case _ => 2 + } + + protected def createBlockStatementAsts(json: Value): List[Ast] = + val blockStmts = sortBlockStatements(json.arr.map(createBabelNodeInfo).toList) + val blockAsts = blockStmts.map(stmt => astForNodeWithFunctionReferenceAndCall(stmt.json)) + setArgumentIndices(blockAsts) + blockAsts + + protected def astForWithStatement(withStatement: BabelNodeInfo): Ast = + val blockNode = createBlockNode(withStatement) + scope.pushNewBlockScope(blockNode) + localAstParentStack.push(blockNode) + val objectAst = astForNodeWithFunctionReferenceAndCall(withStatement.json("object")) + val bodyNodeInfo = createBabelNodeInfo(withStatement.json("body")) + val bodyAsts = bodyNodeInfo.node match + case BlockStatement => createBlockStatementAsts(bodyNodeInfo.json("body")) + case _ => List(astForNodeWithFunctionReferenceAndCall(bodyNodeInfo.json)) + val blockStatementAsts = objectAst +: bodyAsts + setArgumentIndices(blockStatementAsts) + localAstParentStack.pop() + scope.popScope() + blockAst(blockNode, blockStatementAsts) + + protected def astForBlockStatement(block: BabelNodeInfo): Ast = + val blockNode = createBlockNode(block) + scope.pushNewBlockScope(blockNode) + localAstParentStack.push(blockNode) + val blockStatementAsts = createBlockStatementAsts(block.json("body")) + setArgumentIndices(blockStatementAsts) + localAstParentStack.pop() + scope.popScope() + blockAst(blockNode, blockStatementAsts) + + protected def astForReturnStatement(ret: BabelNodeInfo): Ast = + val retCode = ret.code.stripSuffix(";") + val retNode = returnNode(ret, retCode) + safeObj(ret.json, "argument") + .map { argument => + val argAst = astForNodeWithFunctionReference(Obj(argument)) + returnAst(retNode, List(argAst)) } + .getOrElse(Ast(retNode)) - protected def createBlockStatementAsts(json: Value): List[Ast] = - val blockStmts = sortBlockStatements(json.arr.map(createBabelNodeInfo).toList) - val blockAsts = blockStmts.map(stmt => astForNodeWithFunctionReferenceAndCall(stmt.json)) - setArgumentIndices(blockAsts) - blockAsts - - protected def astForWithStatement(withStatement: BabelNodeInfo): Ast = - val blockNode = createBlockNode(withStatement) - scope.pushNewBlockScope(blockNode) - localAstParentStack.push(blockNode) - val objectAst = astForNodeWithFunctionReferenceAndCall(withStatement.json("object")) - val bodyNodeInfo = createBabelNodeInfo(withStatement.json("body")) - val bodyAsts = bodyNodeInfo.node match - case BlockStatement => createBlockStatementAsts(bodyNodeInfo.json("body")) - case _ => List(astForNodeWithFunctionReferenceAndCall(bodyNodeInfo.json)) - val blockStatementAsts = objectAst +: bodyAsts - setArgumentIndices(blockStatementAsts) - localAstParentStack.pop() - scope.popScope() - blockAst(blockNode, blockStatementAsts) - - protected def astForBlockStatement(block: BabelNodeInfo): Ast = - val blockNode = createBlockNode(block) - scope.pushNewBlockScope(blockNode) - localAstParentStack.push(blockNode) - val blockStatementAsts = createBlockStatementAsts(block.json("body")) - setArgumentIndices(blockStatementAsts) - localAstParentStack.pop() - scope.popScope() - blockAst(blockNode, blockStatementAsts) - - protected def astForReturnStatement(ret: BabelNodeInfo): Ast = - val retCode = ret.code.stripSuffix(";") - val retNode = returnNode(ret, retCode) - safeObj(ret.json, "argument") - .map { argument => - val argAst = astForNodeWithFunctionReference(Obj(argument)) - returnAst(retNode, List(argAst)) - } - .getOrElse(Ast(retNode)) - - private def astForCatchClause(catchClause: BabelNodeInfo): Ast = - astForNodeWithFunctionReference(catchClause.json("body")) - - protected def astForTryStatement(tryStmt: BabelNodeInfo): Ast = - val tryNode = createControlStructureNode(tryStmt, ControlStructureTypes.TRY) - val bodyAst = astForNodeWithFunctionReference(tryStmt.json("block")) - val catchAst = safeObj(tryStmt.json, "handler") - .map { handler => - astForCatchClause(createBabelNodeInfo(Obj(handler))) - } - .getOrElse(Ast()) - val finalizerAst = safeObj(tryStmt.json, "finalizer") - .map { finalizer => - astForNodeWithFunctionReference(Obj(finalizer)) - } - .getOrElse(Ast()) - // The semantics of try statement children is defined by there order value. - // Thus we set the here explicitly and do not rely on the usual consecutive - // ordering. - setOrderExplicitly(bodyAst, 1) - setOrderExplicitly(catchAst, 2) - setOrderExplicitly(finalizerAst, 3) - Ast(tryNode).withChildren(List(bodyAst, catchAst, finalizerAst)) - end astForTryStatement - - def astForIfStatement(ifStmt: BabelNodeInfo): Ast = - val ifNode = createControlStructureNode(ifStmt, ControlStructureTypes.IF) - val testAst = astForNodeWithFunctionReference(ifStmt.json("test")) - val consequentAst = astForNodeWithFunctionReference(ifStmt.json("consequent")) - val alternateAst = safeObj(ifStmt.json, "alternate") - .map { alternate => - astForNodeWithFunctionReference(Obj(alternate)) - } - .getOrElse(Ast()) - // The semantics of if statement children is partially defined by there order value. - // The consequentAst must have order == 2 and alternateAst must have order == 3. - // Only to avoid collision we set testAst to 1 - // because the semantics of it is already indicated via the condition edge. - setOrderExplicitly(testAst, 1) - setOrderExplicitly(consequentAst, 2) - setOrderExplicitly(alternateAst, 3) - Ast(ifNode) - .withChild(testAst) - .withConditionEdge(ifNode, testAst.nodes.head) - .withChild(consequentAst) - .withChild(alternateAst) - end astForIfStatement - - protected def astForDoWhileStatement(doWhileStmt: BabelNodeInfo): Ast = - val whileNode = createControlStructureNode(doWhileStmt, ControlStructureTypes.DO) - val testAst = astForNodeWithFunctionReference(doWhileStmt.json("test")) - val bodyAst = astForNodeWithFunctionReference(doWhileStmt.json("body")) - // The semantics of do-while statement children is partially defined by there order value. - // The bodyAst must have order == 1. Only to avoid collision we set testAst to 2 - // because the semantics of it is already indicated via the condition edge. - setOrderExplicitly(bodyAst, 1) - setOrderExplicitly(testAst, 2) - Ast(whileNode).withChild(bodyAst).withChild(testAst).withConditionEdge( - whileNode, - testAst.nodes.head + private def astForCatchClause(catchClause: BabelNodeInfo): Ast = + astForNodeWithFunctionReference(catchClause.json("body")) + + protected def astForTryStatement(tryStmt: BabelNodeInfo): Ast = + val tryNode = createControlStructureNode(tryStmt, ControlStructureTypes.TRY) + val bodyAst = astForNodeWithFunctionReference(tryStmt.json("block")) + val catchAst = safeObj(tryStmt.json, "handler") + .map { handler => + astForCatchClause(createBabelNodeInfo(Obj(handler))) + } + .getOrElse(Ast()) + val finalizerAst = safeObj(tryStmt.json, "finalizer") + .map { finalizer => + astForNodeWithFunctionReference(Obj(finalizer)) + } + .getOrElse(Ast()) + // The semantics of try statement children is defined by there order value. + // Thus we set the here explicitly and do not rely on the usual consecutive + // ordering. + setOrderExplicitly(bodyAst, 1) + setOrderExplicitly(catchAst, 2) + setOrderExplicitly(finalizerAst, 3) + Ast(tryNode).withChildren(List(bodyAst, catchAst, finalizerAst)) + end astForTryStatement + + def astForIfStatement(ifStmt: BabelNodeInfo): Ast = + val ifNode = createControlStructureNode(ifStmt, ControlStructureTypes.IF) + val testAst = astForNodeWithFunctionReference(ifStmt.json("test")) + val consequentAst = astForNodeWithFunctionReference(ifStmt.json("consequent")) + val alternateAst = safeObj(ifStmt.json, "alternate") + .map { alternate => + astForNodeWithFunctionReference(Obj(alternate)) + } + .getOrElse(Ast()) + // The semantics of if statement children is partially defined by there order value. + // The consequentAst must have order == 2 and alternateAst must have order == 3. + // Only to avoid collision we set testAst to 1 + // because the semantics of it is already indicated via the condition edge. + setOrderExplicitly(testAst, 1) + setOrderExplicitly(consequentAst, 2) + setOrderExplicitly(alternateAst, 3) + Ast(ifNode) + .withChild(testAst) + .withConditionEdge(ifNode, testAst.nodes.head) + .withChild(consequentAst) + .withChild(alternateAst) + end astForIfStatement + + protected def astForDoWhileStatement(doWhileStmt: BabelNodeInfo): Ast = + val whileNode = createControlStructureNode(doWhileStmt, ControlStructureTypes.DO) + val testAst = astForNodeWithFunctionReference(doWhileStmt.json("test")) + val bodyAst = astForNodeWithFunctionReference(doWhileStmt.json("body")) + // The semantics of do-while statement children is partially defined by there order value. + // The bodyAst must have order == 1. Only to avoid collision we set testAst to 2 + // because the semantics of it is already indicated via the condition edge. + setOrderExplicitly(bodyAst, 1) + setOrderExplicitly(testAst, 2) + Ast(whileNode).withChild(bodyAst).withChild(testAst).withConditionEdge( + whileNode, + testAst.nodes.head + ) + + protected def astForWhileStatement(whileStmt: BabelNodeInfo): Ast = + val whileNode = createControlStructureNode(whileStmt, ControlStructureTypes.WHILE) + val testAst = astForNodeWithFunctionReference(whileStmt.json("test")) + val bodyAst = astForNodeWithFunctionReference(whileStmt.json("body")) + // The semantics of while statement children is partially defined by there order value. + // The bodyAst must have order == 2. Only to avoid collision we set testAst to 1 + // because the semantics of it is already indicated via the condition edge. + setOrderExplicitly(testAst, 1) + setOrderExplicitly(bodyAst, 2) + Ast(whileNode).withChild(testAst).withConditionEdge( + whileNode, + testAst.nodes.head + ).withChild(bodyAst) + + protected def astForForStatement(forStmt: BabelNodeInfo): Ast = + val forNode = createControlStructureNode(forStmt, ControlStructureTypes.FOR) + val initAst = safeObj(forStmt.json, "init") + .map { init => + astForNodeWithFunctionReference(Obj(init)) + } + .getOrElse(Ast()) + val testAst = safeObj(forStmt.json, "test") + .map { test => + astForNodeWithFunctionReference(Obj(test)) + } + .getOrElse(Ast(literalNode(forStmt, "true", Option(Defines.Boolean)))) + val updateAst = safeObj(forStmt.json, "update") + .map { update => + astForNodeWithFunctionReference(Obj(update)) + } + .getOrElse(Ast()) + val bodyAst = astForNodeWithFunctionReference(forStmt.json("body")) + + // The semantics of for statement children is defined by there order value. + // Thus we set the here explicitly and do not rely on the usual consecutive + // ordering. + setOrderExplicitly(initAst, 1) + setOrderExplicitly(testAst, 2) + setOrderExplicitly(updateAst, 3) + setOrderExplicitly(bodyAst, 4) + Ast(forNode).withChild(initAst).withChild(testAst).withChild(updateAst).withChild(bodyAst) + end astForForStatement + + protected def astForLabeledStatement(labelStmt: BabelNodeInfo): Ast = + val labelName = code(labelStmt.json("label")) + val labeledNode = NewJumpTarget() + .parserTypeName(labelStmt.node.toString) + .name(labelName) + .code(s"$labelName:") + .lineNumber(labelStmt.lineNumber) + .columnNumber(labelStmt.columnNumber) + + val blockNode = createBlockNode(labelStmt) + scope.pushNewBlockScope(blockNode) + localAstParentStack.push(blockNode) + val bodyAst = astForNodeWithFunctionReference(labelStmt.json("body")) + scope.popScope() + localAstParentStack.pop() + + val labelAsts = List(Ast(labeledNode), bodyAst) + setArgumentIndices(labelAsts) + blockAst(blockNode, labelAsts) + end astForLabeledStatement + + protected def astForBreakStatement(breakStmt: BabelNodeInfo): Ast = + val labelAst = safeObj(breakStmt.json, "label") + .map { label => + val labelNode = Obj(label) + val labelCode = code(labelNode) + Ast( + NewJumpLabel() + .parserTypeName(breakStmt.node.toString) + .name(labelCode) + .code(labelCode) + .lineNumber(breakStmt.lineNumber) + .columnNumber(breakStmt.columnNumber) + .order(1) + ) + } + .getOrElse(Ast()) + Ast(createControlStructureNode(breakStmt, ControlStructureTypes.BREAK)).withChild(labelAst) + + protected def astForContinueStatement(continueStmt: BabelNodeInfo): Ast = + val labelAst = safeObj(continueStmt.json, "label") + .map { label => + val labelNode = Obj(label) + val labelCode = code(labelNode) + Ast( + NewJumpLabel() + .parserTypeName(continueStmt.node.toString) + .name(labelCode) + .code(labelCode) + .lineNumber(continueStmt.lineNumber) + .columnNumber(continueStmt.columnNumber) + .order(1) + ) + } + .getOrElse(Ast()) + Ast(createControlStructureNode(continueStmt, ControlStructureTypes.CONTINUE)).withChild( + labelAst + ) + end astForContinueStatement + + protected def astForThrowStatement(throwStmt: BabelNodeInfo): Ast = + val argumentAst = astForNodeWithFunctionReference(throwStmt.json("argument")) + val throwCallNode = + callNode(throwStmt, throwStmt.code, ".throw", DispatchTypes.STATIC_DISPATCH) + val argAsts = List(argumentAst) + callAst(throwCallNode, argAsts) + + private def astsForSwitchCase(switchCase: BabelNodeInfo): List[Ast] = + val labelAst = Ast(createJumpTarget(switchCase)) + val testAsts = safeObj(switchCase.json, "test").map(t => + astForNodeWithFunctionReference(Obj(t)) + ).toList + val consequentAsts = astForNodes(switchCase.json("consequent").arr.toList) + labelAst +: (testAsts ++ consequentAsts) + + protected def astForSwitchStatement(switchStmt: BabelNodeInfo): Ast = + val switchNode = createControlStructureNode(switchStmt, ControlStructureTypes.SWITCH) + + // The semantics of switch statement children is partially defined by there order value. + // The blockAst must have order == 2. Only to avoid collision we set switchExpressionAst to 1 + // because the semantics of it is already indicated via the condition edge. + val switchExpressionAst = astForNodeWithFunctionReference(switchStmt.json("discriminant")) + setOrderExplicitly(switchExpressionAst, 1) + + val blockNode = createBlockNode(switchStmt).order(2) + scope.pushNewBlockScope(blockNode) + localAstParentStack.push(blockNode) + + val casesAsts = + switchStmt.json("cases").arr.flatMap(c => astsForSwitchCase(createBabelNodeInfo(c))) + setArgumentIndices(casesAsts.toList) + + scope.popScope() + localAstParentStack.pop() + + Ast(switchNode) + .withChild(switchExpressionAst) + .withConditionEdge(switchNode, switchExpressionAst.nodes.head) + .withChild(blockAst(blockNode, casesAsts.toList)) + end astForSwitchStatement + + /** De-sugaring from: + * + * for (var i in/of arr) { body } + * + * to: + * + * { var _iterator = .iterator(arr); var _result; var i; while (!(_result = + * _iterator.next()).done) { i = _result.value; body } } + */ + private def astForInOfStatementWithIdentifier( + forInOfStmt: BabelNodeInfo, + idNodeInfo: BabelNodeInfo + ): Ast = + // surrounding block: + val blockNode = createBlockNode(forInOfStmt) + scope.pushNewBlockScope(blockNode) + localAstParentStack.push(blockNode) + + val collection = forInOfStmt.json("right") + val collectionName = code(collection) + + // _iterator assignment: + val iteratorName = generateUnusedVariableName(usedVariableNames, "_iterator") + val iteratorLocalNode = newLocalNode(iteratorName, Defines.Any).order(0) + val iteratorNode = identifierNode(forInOfStmt, iteratorName) + diffGraph.addEdge(localAstParentStack.head, iteratorLocalNode, EdgeTypes.AST) + scope.addVariableReference(iteratorName, iteratorNode) + + val iteratorCall = + // TODO: add operator to schema + callNode( + forInOfStmt, + s".iterator($collectionName)", + ".iterator", + DispatchTypes.STATIC_DISPATCH ) - protected def astForWhileStatement(whileStmt: BabelNodeInfo): Ast = - val whileNode = createControlStructureNode(whileStmt, ControlStructureTypes.WHILE) - val testAst = astForNodeWithFunctionReference(whileStmt.json("test")) - val bodyAst = astForNodeWithFunctionReference(whileStmt.json("body")) - // The semantics of while statement children is partially defined by there order value. - // The bodyAst must have order == 2. Only to avoid collision we set testAst to 1 - // because the semantics of it is already indicated via the condition edge. - setOrderExplicitly(testAst, 1) - setOrderExplicitly(bodyAst, 2) - Ast(whileNode).withChild(testAst).withConditionEdge( - whileNode, - testAst.nodes.head - ).withChild(bodyAst) - - protected def astForForStatement(forStmt: BabelNodeInfo): Ast = - val forNode = createControlStructureNode(forStmt, ControlStructureTypes.FOR) - val initAst = safeObj(forStmt.json, "init") - .map { init => - astForNodeWithFunctionReference(Obj(init)) - } - .getOrElse(Ast()) - val testAst = safeObj(forStmt.json, "test") - .map { test => - astForNodeWithFunctionReference(Obj(test)) - } - .getOrElse(Ast(literalNode(forStmt, "true", Option(Defines.Boolean)))) - val updateAst = safeObj(forStmt.json, "update") - .map { update => - astForNodeWithFunctionReference(Obj(update)) - } - .getOrElse(Ast()) - val bodyAst = astForNodeWithFunctionReference(forStmt.json("body")) - - // The semantics of for statement children is defined by there order value. - // Thus we set the here explicitly and do not rely on the usual consecutive - // ordering. - setOrderExplicitly(initAst, 1) - setOrderExplicitly(testAst, 2) - setOrderExplicitly(updateAst, 3) - setOrderExplicitly(bodyAst, 4) - Ast(forNode).withChild(initAst).withChild(testAst).withChild(updateAst).withChild(bodyAst) - end astForForStatement - - protected def astForLabeledStatement(labelStmt: BabelNodeInfo): Ast = - val labelName = code(labelStmt.json("label")) - val labeledNode = NewJumpTarget() - .parserTypeName(labelStmt.node.toString) - .name(labelName) - .code(s"$labelName:") - .lineNumber(labelStmt.lineNumber) - .columnNumber(labelStmt.columnNumber) - - val blockNode = createBlockNode(labelStmt) - scope.pushNewBlockScope(blockNode) - localAstParentStack.push(blockNode) - val bodyAst = astForNodeWithFunctionReference(labelStmt.json("body")) - scope.popScope() - localAstParentStack.pop() - - val labelAsts = List(Ast(labeledNode), bodyAst) - setArgumentIndices(labelAsts) - blockAst(blockNode, labelAsts) - end astForLabeledStatement - - protected def astForBreakStatement(breakStmt: BabelNodeInfo): Ast = - val labelAst = safeObj(breakStmt.json, "label") - .map { label => - val labelNode = Obj(label) - val labelCode = code(labelNode) - Ast( - NewJumpLabel() - .parserTypeName(breakStmt.node.toString) - .name(labelCode) - .code(labelCode) - .lineNumber(breakStmt.lineNumber) - .columnNumber(breakStmt.columnNumber) - .order(1) - ) - } - .getOrElse(Ast()) - Ast(createControlStructureNode(breakStmt, ControlStructureTypes.BREAK)).withChild(labelAst) - - protected def astForContinueStatement(continueStmt: BabelNodeInfo): Ast = - val labelAst = safeObj(continueStmt.json, "label") - .map { label => - val labelNode = Obj(label) - val labelCode = code(labelNode) - Ast( - NewJumpLabel() - .parserTypeName(continueStmt.node.toString) - .name(labelCode) - .code(labelCode) - .lineNumber(continueStmt.lineNumber) - .columnNumber(continueStmt.columnNumber) - .order(1) - ) - } - .getOrElse(Ast()) - Ast(createControlStructureNode(continueStmt, ControlStructureTypes.CONTINUE)).withChild( - labelAst + val objectKeysCallArgs = List(astForNodeWithFunctionReference(collection)) + val objectKeysCallAst = callAst(iteratorCall, objectKeysCallArgs) + + val iteratorAssignmentNode = + callNode( + forInOfStmt, + s"$iteratorName = .iterator($collectionName)", + Operators.assignment, + DispatchTypes.STATIC_DISPATCH ) - end astForContinueStatement - - protected def astForThrowStatement(throwStmt: BabelNodeInfo): Ast = - val argumentAst = astForNodeWithFunctionReference(throwStmt.json("argument")) - val throwCallNode = - callNode(throwStmt, throwStmt.code, ".throw", DispatchTypes.STATIC_DISPATCH) - val argAsts = List(argumentAst) - callAst(throwCallNode, argAsts) - - private def astsForSwitchCase(switchCase: BabelNodeInfo): List[Ast] = - val labelAst = Ast(createJumpTarget(switchCase)) - val testAsts = safeObj(switchCase.json, "test").map(t => - astForNodeWithFunctionReference(Obj(t)) - ).toList - val consequentAsts = astForNodes(switchCase.json("consequent").arr.toList) - labelAst +: (testAsts ++ consequentAsts) - - protected def astForSwitchStatement(switchStmt: BabelNodeInfo): Ast = - val switchNode = createControlStructureNode(switchStmt, ControlStructureTypes.SWITCH) - - // The semantics of switch statement children is partially defined by there order value. - // The blockAst must have order == 2. Only to avoid collision we set switchExpressionAst to 1 - // because the semantics of it is already indicated via the condition edge. - val switchExpressionAst = astForNodeWithFunctionReference(switchStmt.json("discriminant")) - setOrderExplicitly(switchExpressionAst, 1) - - val blockNode = createBlockNode(switchStmt).order(2) - scope.pushNewBlockScope(blockNode) - localAstParentStack.push(blockNode) - - val casesAsts = - switchStmt.json("cases").arr.flatMap(c => astsForSwitchCase(createBabelNodeInfo(c))) - setArgumentIndices(casesAsts.toList) - - scope.popScope() - localAstParentStack.pop() - - Ast(switchNode) - .withChild(switchExpressionAst) - .withConditionEdge(switchNode, switchExpressionAst.nodes.head) - .withChild(blockAst(blockNode, casesAsts.toList)) - end astForSwitchStatement - - /** De-sugaring from: - * - * for (var i in/of arr) { body } - * - * to: - * - * { var _iterator = .iterator(arr); var _result; var i; while (!(_result = - * _iterator.next()).done) { i = _result.value; body } } - */ - private def astForInOfStatementWithIdentifier( - forInOfStmt: BabelNodeInfo, - idNodeInfo: BabelNodeInfo - ): Ast = - // surrounding block: - val blockNode = createBlockNode(forInOfStmt) - scope.pushNewBlockScope(blockNode) - localAstParentStack.push(blockNode) - - val collection = forInOfStmt.json("right") - val collectionName = code(collection) - - // _iterator assignment: - val iteratorName = generateUnusedVariableName(usedVariableNames, "_iterator") - val iteratorLocalNode = newLocalNode(iteratorName, Defines.Any).order(0) - val iteratorNode = identifierNode(forInOfStmt, iteratorName) - diffGraph.addEdge(localAstParentStack.head, iteratorLocalNode, EdgeTypes.AST) - scope.addVariableReference(iteratorName, iteratorNode) - - val iteratorCall = - // TODO: add operator to schema - callNode( - forInOfStmt, - s".iterator($collectionName)", - ".iterator", - DispatchTypes.STATIC_DISPATCH - ) - val objectKeysCallArgs = List(astForNodeWithFunctionReference(collection)) - val objectKeysCallAst = callAst(iteratorCall, objectKeysCallArgs) + val iteratorAssignmentArgs = List(Ast(iteratorNode), objectKeysCallAst) + val iteratorAssignmentAst = callAst(iteratorAssignmentNode, iteratorAssignmentArgs) - val iteratorAssignmentNode = - callNode( - forInOfStmt, - s"$iteratorName = .iterator($collectionName)", - Operators.assignment, - DispatchTypes.STATIC_DISPATCH - ) + // _result: + val resultName = generateUnusedVariableName(usedVariableNames, "_result") + val resultLocalNode = newLocalNode(resultName, Defines.Any).order(0) + val resultNode = identifierNode(forInOfStmt, resultName) + diffGraph.addEdge(localAstParentStack.head, resultLocalNode, EdgeTypes.AST) + scope.addVariableReference(resultName, resultNode) - val iteratorAssignmentArgs = List(Ast(iteratorNode), objectKeysCallAst) - val iteratorAssignmentAst = callAst(iteratorAssignmentNode, iteratorAssignmentArgs) + // loop variable: + val loopVariableName = idNodeInfo.code - // _result: - val resultName = generateUnusedVariableName(usedVariableNames, "_result") - val resultLocalNode = newLocalNode(resultName, Defines.Any).order(0) - val resultNode = identifierNode(forInOfStmt, resultName) - diffGraph.addEdge(localAstParentStack.head, resultLocalNode, EdgeTypes.AST) - scope.addVariableReference(resultName, resultNode) + val loopVariableLocalNode = newLocalNode(loopVariableName, Defines.Any).order(0) + val loopVariableNode = identifierNode(forInOfStmt, loopVariableName) + diffGraph.addEdge(localAstParentStack.head, loopVariableLocalNode, EdgeTypes.AST) + scope.addVariableReference(loopVariableName, loopVariableNode) - // loop variable: - val loopVariableName = idNodeInfo.code + // while loop: + val whileLoopNode = createControlStructureNode(forInOfStmt, ControlStructureTypes.WHILE) - val loopVariableLocalNode = newLocalNode(loopVariableName, Defines.Any).order(0) - val loopVariableNode = identifierNode(forInOfStmt, loopVariableName) - diffGraph.addEdge(localAstParentStack.head, loopVariableLocalNode, EdgeTypes.AST) - scope.addVariableReference(loopVariableName, loopVariableNode) + // while loop test: + val testCallNode = + callNode( + forInOfStmt, + s"!($resultName = $iteratorName.next()).done", + Operators.not, + DispatchTypes.STATIC_DISPATCH + ) - // while loop: - val whileLoopNode = createControlStructureNode(forInOfStmt, ControlStructureTypes.WHILE) + val doneBaseNode = + callNode( + forInOfStmt, + s"($resultName = $iteratorName.next())", + Operators.assignment, + DispatchTypes.STATIC_DISPATCH + ) - // while loop test: - val testCallNode = - callNode( - forInOfStmt, - s"!($resultName = $iteratorName.next()).done", - Operators.not, - DispatchTypes.STATIC_DISPATCH - ) + val lhsNode = identifierNode(forInOfStmt, resultName) - val doneBaseNode = - callNode( - forInOfStmt, - s"($resultName = $iteratorName.next())", - Operators.assignment, - DispatchTypes.STATIC_DISPATCH - ) + val rhsNode = + callNode(forInOfStmt, s"$iteratorName.next()", "next", DispatchTypes.DYNAMIC_DISPATCH) - val lhsNode = identifierNode(forInOfStmt, resultName) + val nextBaseNode = identifierNode(forInOfStmt, iteratorName) - val rhsNode = - callNode(forInOfStmt, s"$iteratorName.next()", "next", DispatchTypes.DYNAMIC_DISPATCH) + val nextMemberNode = + createFieldIdentifierNode("next", forInOfStmt.lineNumber, forInOfStmt.columnNumber) - val nextBaseNode = identifierNode(forInOfStmt, iteratorName) + val nextReceiverNode = + createFieldAccessCallAst( + nextBaseNode, + nextMemberNode, + forInOfStmt.lineNumber, + forInOfStmt.columnNumber + ) - val nextMemberNode = - createFieldIdentifierNode("next", forInOfStmt.lineNumber, forInOfStmt.columnNumber) + val thisNextNode = identifierNode(forInOfStmt, iteratorName) - val nextReceiverNode = - createFieldAccessCallAst( - nextBaseNode, - nextMemberNode, - forInOfStmt.lineNumber, - forInOfStmt.columnNumber - ) + val rhsArgs = List(Ast(thisNextNode)) + val rhsAst = callAst(rhsNode, rhsArgs, receiver = Option(nextReceiverNode)) - val thisNextNode = identifierNode(forInOfStmt, iteratorName) + val doneBaseArgs = List(Ast(lhsNode), rhsAst) + val doneBaseAst = callAst(doneBaseNode, doneBaseArgs) + Ast.storeInDiffGraph(doneBaseAst, diffGraph) - val rhsArgs = List(Ast(thisNextNode)) - val rhsAst = callAst(rhsNode, rhsArgs, receiver = Option(nextReceiverNode)) + val doneMemberNode = + createFieldIdentifierNode("done", forInOfStmt.lineNumber, forInOfStmt.columnNumber) - val doneBaseArgs = List(Ast(lhsNode), rhsAst) - val doneBaseAst = callAst(doneBaseNode, doneBaseArgs) - Ast.storeInDiffGraph(doneBaseAst, diffGraph) + val testNode = + createFieldAccessCallAst( + doneBaseNode, + doneMemberNode, + forInOfStmt.lineNumber, + forInOfStmt.columnNumber + ) - val doneMemberNode = - createFieldIdentifierNode("done", forInOfStmt.lineNumber, forInOfStmt.columnNumber) + val testCallArgs = List(testNode) + val testCallAst = callAst(testCallNode, testCallArgs) - val testNode = - createFieldAccessCallAst( - doneBaseNode, - doneMemberNode, - forInOfStmt.lineNumber, - forInOfStmt.columnNumber - ) + val whileLoopAst = + Ast(whileLoopNode).withChild(testCallAst).withConditionEdge(whileLoopNode, testCallNode) - val testCallArgs = List(testNode) - val testCallAst = callAst(testCallNode, testCallArgs) + // while loop variable assignment: + val whileLoopVariableNode = identifierNode(forInOfStmt, loopVariableName) - val whileLoopAst = - Ast(whileLoopNode).withChild(testCallAst).withConditionEdge(whileLoopNode, testCallNode) + val baseNode = identifierNode(forInOfStmt, resultName) - // while loop variable assignment: - val whileLoopVariableNode = identifierNode(forInOfStmt, loopVariableName) + val memberNode = + createFieldIdentifierNode("value", forInOfStmt.lineNumber, forInOfStmt.columnNumber) - val baseNode = identifierNode(forInOfStmt, resultName) + val accessAst = createFieldAccessCallAst( + baseNode, + memberNode, + forInOfStmt.lineNumber, + forInOfStmt.columnNumber + ) - val memberNode = - createFieldIdentifierNode("value", forInOfStmt.lineNumber, forInOfStmt.columnNumber) + val loopVariableAssignmentNode = callNode( + forInOfStmt, + s"$loopVariableName = $resultName.value", + Operators.assignment, + DispatchTypes.STATIC_DISPATCH + ) - val accessAst = createFieldAccessCallAst( - baseNode, - memberNode, - forInOfStmt.lineNumber, - forInOfStmt.columnNumber + val loopVariableAssignmentArgs = List(Ast(whileLoopVariableNode), accessAst) + val loopVariableAssignmentAst = + callAst(loopVariableAssignmentNode, loopVariableAssignmentArgs) + + val whileLoopBlockNode = createBlockNode(forInOfStmt) + scope.pushNewBlockScope(whileLoopBlockNode) + localAstParentStack.push(whileLoopBlockNode) + + // while loop block: + val bodyAst = astForNodeWithFunctionReference(forInOfStmt.json("body")) + + val whileLoopBlockChildren = List(loopVariableAssignmentAst, bodyAst) + setArgumentIndices(whileLoopBlockChildren) + val whileLoopBlockAst = blockAst(whileLoopBlockNode, whileLoopBlockChildren) + + scope.popScope() + localAstParentStack.pop() + + // end surrounding block: + scope.popScope() + localAstParentStack.pop() + + val blockChildren = + List( + iteratorAssignmentAst, + Ast(resultNode), + Ast(loopVariableNode), + whileLoopAst.withChild(whileLoopBlockAst) + ) + setArgumentIndices(blockChildren) + blockAst(blockNode, blockChildren) + end astForInOfStatementWithIdentifier + + /** De-sugaring from: + * + * for (expr in/of arr) { body } + * + * to: + * + * { var _iterator = .iterator(arr); var _result; while (!(_result = + * _iterator.next()).done) { expr = _result.value; body } } + */ + private def astForInOfStatementWithExpression( + forInOfStmt: BabelNodeInfo, + idNodeInfo: BabelNodeInfo + ): Ast = + // surrounding block: + val blockNode = createBlockNode(forInOfStmt) + scope.pushNewBlockScope(blockNode) + localAstParentStack.push(blockNode) + + val collection = forInOfStmt.json("right") + val collectionName = code(collection) + + // _iterator assignment: + val iteratorName = generateUnusedVariableName(usedVariableNames, "_iterator") + val iteratorLocalNode = newLocalNode(iteratorName, Defines.Any).order(0) + val iteratorNode = identifierNode(forInOfStmt, iteratorName) + diffGraph.addEdge(localAstParentStack.head, iteratorLocalNode, EdgeTypes.AST) + scope.addVariableReference(iteratorName, iteratorNode) + + val iteratorCall = + // TODO: add operator to schema + callNode( + forInOfStmt, + s".iterator($collectionName)", + ".iterator", + DispatchTypes.STATIC_DISPATCH ) - val loopVariableAssignmentNode = callNode( + val objectKeysCallArgs = List(astForNodeWithFunctionReference(collection)) + val objectKeysCallAst = callAst(iteratorCall, objectKeysCallArgs) + + val iteratorAssignmentNode = + callNode( forInOfStmt, - s"$loopVariableName = $resultName.value", + s"$iteratorName = .iterator($collectionName)", Operators.assignment, DispatchTypes.STATIC_DISPATCH ) - val loopVariableAssignmentArgs = List(Ast(whileLoopVariableNode), accessAst) - val loopVariableAssignmentAst = - callAst(loopVariableAssignmentNode, loopVariableAssignmentArgs) + val iteratorAssignmentArgs = List(Ast(iteratorNode), objectKeysCallAst) + val iteratorAssignmentAst = callAst(iteratorAssignmentNode, iteratorAssignmentArgs) - val whileLoopBlockNode = createBlockNode(forInOfStmt) - scope.pushNewBlockScope(whileLoopBlockNode) - localAstParentStack.push(whileLoopBlockNode) + // _result: + val resultName = generateUnusedVariableName(usedVariableNames, "_result") + val resultLocalNode = newLocalNode(resultName, Defines.Any).order(0) + val resultNode = identifierNode(forInOfStmt, resultName) + diffGraph.addEdge(localAstParentStack.head, resultLocalNode, EdgeTypes.AST) + scope.addVariableReference(resultName, resultNode) - // while loop block: - val bodyAst = astForNodeWithFunctionReference(forInOfStmt.json("body")) + // while loop: + val whileLoopNode = createControlStructureNode(forInOfStmt, ControlStructureTypes.WHILE) - val whileLoopBlockChildren = List(loopVariableAssignmentAst, bodyAst) - setArgumentIndices(whileLoopBlockChildren) - val whileLoopBlockAst = blockAst(whileLoopBlockNode, whileLoopBlockChildren) + // while loop test: + val testCallNode = + callNode( + forInOfStmt, + s"!($resultName = $iteratorName.next()).done", + Operators.not, + DispatchTypes.STATIC_DISPATCH + ) - scope.popScope() - localAstParentStack.pop() + val doneBaseNode = + callNode( + forInOfStmt, + s"($resultName = $iteratorName.next())", + Operators.assignment, + DispatchTypes.STATIC_DISPATCH + ) - // end surrounding block: - scope.popScope() - localAstParentStack.pop() + val lhsNode = identifierNode(forInOfStmt, resultName) - val blockChildren = - List( - iteratorAssignmentAst, - Ast(resultNode), - Ast(loopVariableNode), - whileLoopAst.withChild(whileLoopBlockAst) - ) - setArgumentIndices(blockChildren) - blockAst(blockNode, blockChildren) - end astForInOfStatementWithIdentifier - - /** De-sugaring from: - * - * for (expr in/of arr) { body } - * - * to: - * - * { var _iterator = .iterator(arr); var _result; while (!(_result = - * _iterator.next()).done) { expr = _result.value; body } } - */ - private def astForInOfStatementWithExpression( - forInOfStmt: BabelNodeInfo, - idNodeInfo: BabelNodeInfo - ): Ast = - // surrounding block: - val blockNode = createBlockNode(forInOfStmt) - scope.pushNewBlockScope(blockNode) - localAstParentStack.push(blockNode) - - val collection = forInOfStmt.json("right") - val collectionName = code(collection) - - // _iterator assignment: - val iteratorName = generateUnusedVariableName(usedVariableNames, "_iterator") - val iteratorLocalNode = newLocalNode(iteratorName, Defines.Any).order(0) - val iteratorNode = identifierNode(forInOfStmt, iteratorName) - diffGraph.addEdge(localAstParentStack.head, iteratorLocalNode, EdgeTypes.AST) - scope.addVariableReference(iteratorName, iteratorNode) - - val iteratorCall = - // TODO: add operator to schema - callNode( - forInOfStmt, - s".iterator($collectionName)", - ".iterator", - DispatchTypes.STATIC_DISPATCH - ) + val rhsNode = + callNode(forInOfStmt, s"$iteratorName.next()", "next", DispatchTypes.DYNAMIC_DISPATCH) - val objectKeysCallArgs = List(astForNodeWithFunctionReference(collection)) - val objectKeysCallAst = callAst(iteratorCall, objectKeysCallArgs) + val nextBaseNode = identifierNode(forInOfStmt, iteratorName) - val iteratorAssignmentNode = - callNode( - forInOfStmt, - s"$iteratorName = .iterator($collectionName)", - Operators.assignment, - DispatchTypes.STATIC_DISPATCH - ) + val nextMemberNode = + createFieldIdentifierNode("next", forInOfStmt.lineNumber, forInOfStmt.columnNumber) - val iteratorAssignmentArgs = List(Ast(iteratorNode), objectKeysCallAst) - val iteratorAssignmentAst = callAst(iteratorAssignmentNode, iteratorAssignmentArgs) + val nextReceiverNode = + createFieldAccessCallAst( + nextBaseNode, + nextMemberNode, + forInOfStmt.lineNumber, + forInOfStmt.columnNumber + ) - // _result: - val resultName = generateUnusedVariableName(usedVariableNames, "_result") - val resultLocalNode = newLocalNode(resultName, Defines.Any).order(0) - val resultNode = identifierNode(forInOfStmt, resultName) - diffGraph.addEdge(localAstParentStack.head, resultLocalNode, EdgeTypes.AST) - scope.addVariableReference(resultName, resultNode) + val thisNextNode = identifierNode(forInOfStmt, iteratorName) - // while loop: - val whileLoopNode = createControlStructureNode(forInOfStmt, ControlStructureTypes.WHILE) + val rhsArgs = List(Ast(thisNextNode)) + val rhsAst = callAst(rhsNode, rhsArgs, receiver = Option(nextReceiverNode)) - // while loop test: - val testCallNode = - callNode( - forInOfStmt, - s"!($resultName = $iteratorName.next()).done", - Operators.not, - DispatchTypes.STATIC_DISPATCH - ) + val doneBaseArgs = List(Ast(lhsNode), rhsAst) + val doneBaseAst = callAst(doneBaseNode, doneBaseArgs) + Ast.storeInDiffGraph(doneBaseAst, diffGraph) - val doneBaseNode = - callNode( - forInOfStmt, - s"($resultName = $iteratorName.next())", - Operators.assignment, - DispatchTypes.STATIC_DISPATCH - ) + val doneMemberNode = + createFieldIdentifierNode("done", forInOfStmt.lineNumber, forInOfStmt.columnNumber) - val lhsNode = identifierNode(forInOfStmt, resultName) + val testNode = + createFieldAccessCallAst( + doneBaseNode, + doneMemberNode, + forInOfStmt.lineNumber, + forInOfStmt.columnNumber + ) - val rhsNode = - callNode(forInOfStmt, s"$iteratorName.next()", "next", DispatchTypes.DYNAMIC_DISPATCH) + val testCallArgs = List(testNode) + val testCallAst = callAst(testCallNode, testCallArgs) + + val whileLoopAst = + Ast(whileLoopNode).withChild(testCallAst).withConditionEdge(whileLoopNode, testCallNode) + + // while loop variable assignment: + val whileLoopVariableNode = astForNode(idNodeInfo.json) + + val baseNode = identifierNode(forInOfStmt, resultName) + + val memberNode = + createFieldIdentifierNode("value", forInOfStmt.lineNumber, forInOfStmt.columnNumber) + + val accessAst = createFieldAccessCallAst( + baseNode, + memberNode, + forInOfStmt.lineNumber, + forInOfStmt.columnNumber + ) + + val loopVariableAssignmentNode = callNode( + forInOfStmt, + s"${idNodeInfo.code} = $resultName.value", + Operators.assignment, + DispatchTypes.STATIC_DISPATCH + ) + + val loopVariableAssignmentArgs = List(whileLoopVariableNode, accessAst) + val loopVariableAssignmentAst = + callAst(loopVariableAssignmentNode, loopVariableAssignmentArgs) + + val whileLoopBlockNode = createBlockNode(forInOfStmt) + scope.pushNewBlockScope(whileLoopBlockNode) + localAstParentStack.push(whileLoopBlockNode) + + // while loop block: + val bodyAst = astForNodeWithFunctionReference(forInOfStmt.json("body")) + + val whileLoopBlockChildren = List(loopVariableAssignmentAst, bodyAst) + setArgumentIndices(whileLoopBlockChildren) + val whileLoopBlockAst = blockAst(whileLoopBlockNode, whileLoopBlockChildren) + + scope.popScope() + localAstParentStack.pop() + + // end surrounding block: + scope.popScope() + localAstParentStack.pop() + + val blockChildren = + List(iteratorAssignmentAst, Ast(resultNode), whileLoopAst.withChild(whileLoopBlockAst)) + setArgumentIndices(blockChildren) + blockAst(blockNode, blockChildren) + end astForInOfStatementWithExpression + + /** De-sugaring from: + * + * for(var {a, b, c} of obj) { body } + * + * to: + * + * { var _iterator = .iterator(obj); var _result; var a; var b; var c; while (!(_result + * \= _iterator.next()).done) { a = _result.value.a; b = _result.value.b; c = _result.value.c; + * body } } + */ + private def astForInOfStatementWithObject( + forInOfStmt: BabelNodeInfo, + idNodeInfo: BabelNodeInfo + ): Ast = + // surrounding block: + val blockNode = createBlockNode(forInOfStmt) + scope.pushNewBlockScope(blockNode) + localAstParentStack.push(blockNode) + + val collection = forInOfStmt.json("right") + val collectionName = code(collection) + + // _iterator assignment: + val iteratorName = generateUnusedVariableName(usedVariableNames, "_iterator") + val iteratorLocalNode = newLocalNode(iteratorName, Defines.Any).order(0) + val iteratorNode = identifierNode(forInOfStmt, iteratorName) + diffGraph.addEdge(localAstParentStack.head, iteratorLocalNode, EdgeTypes.AST) + scope.addVariableReference(iteratorName, iteratorNode) + // TODO: add operator to schema + val iteratorCall = + callNode( + forInOfStmt, + s".iterator($collectionName)", + ".iterator", + DispatchTypes.STATIC_DISPATCH + ) - val nextBaseNode = identifierNode(forInOfStmt, iteratorName) + val objectKeysCallArgs = List(astForNodeWithFunctionReference(collection)) + val objectKeysCallAst = callAst(iteratorCall, objectKeysCallArgs) - val nextMemberNode = - createFieldIdentifierNode("next", forInOfStmt.lineNumber, forInOfStmt.columnNumber) + val iteratorAssignmentNode = + callNode( + forInOfStmt, + s"$iteratorName = .iterator($collectionName)", + Operators.assignment, + DispatchTypes.STATIC_DISPATCH + ) - val nextReceiverNode = - createFieldAccessCallAst( - nextBaseNode, - nextMemberNode, - forInOfStmt.lineNumber, - forInOfStmt.columnNumber - ) + val iteratorAssignmentArgs = List(Ast(iteratorNode), objectKeysCallAst) + val iteratorAssignmentAst = callAst(iteratorAssignmentNode, iteratorAssignmentArgs) + + // _result: + val resultName = generateUnusedVariableName(usedVariableNames, "_result") + val resultLocalNode = newLocalNode(resultName, Defines.Any).order(0) + val resultNode = identifierNode(forInOfStmt, resultName) + diffGraph.addEdge(localAstParentStack.head, resultLocalNode, EdgeTypes.AST) + scope.addVariableReference(resultName, resultNode) + + // loop variable: + val loopVariableNames = idNodeInfo.json("properties").arr.toList.map(code) + + val loopVariableLocalNodes = loopVariableNames.map(newLocalNode(_, Defines.Any).order(0)) + val loopVariableNodes = loopVariableNames.map(identifierNode(forInOfStmt, _)) + loopVariableLocalNodes.foreach(diffGraph.addEdge( + localAstParentStack.head, + _, + EdgeTypes.AST + )) + loopVariableNames.zip(loopVariableNodes).foreach { + case (loopVariableName, loopVariableNode) => + scope.addVariableReference(loopVariableName, loopVariableNode) + } + + // while loop: + val whileLoopNode = createControlStructureNode(forInOfStmt, ControlStructureTypes.WHILE) + + // while loop test: + val testCallNode = + callNode( + forInOfStmt, + s"!($resultName = $iteratorName.next()).done", + Operators.not, + DispatchTypes.STATIC_DISPATCH + ) - val thisNextNode = identifierNode(forInOfStmt, iteratorName) + val doneBaseNode = callNode( + forInOfStmt, + s"($resultName = $iteratorName.next())", + Operators.assignment, + DispatchTypes.STATIC_DISPATCH + ) - val rhsArgs = List(Ast(thisNextNode)) - val rhsAst = callAst(rhsNode, rhsArgs, receiver = Option(nextReceiverNode)) + val lhsNode = identifierNode(forInOfStmt, resultName) - val doneBaseArgs = List(Ast(lhsNode), rhsAst) - val doneBaseAst = callAst(doneBaseNode, doneBaseArgs) - Ast.storeInDiffGraph(doneBaseAst, diffGraph) + val rhsNode = + callNode(forInOfStmt, s"$iteratorName.next()", "next", DispatchTypes.DYNAMIC_DISPATCH) - val doneMemberNode = - createFieldIdentifierNode("done", forInOfStmt.lineNumber, forInOfStmt.columnNumber) + val nextBaseNode = identifierNode(forInOfStmt, iteratorName) - val testNode = - createFieldAccessCallAst( - doneBaseNode, - doneMemberNode, - forInOfStmt.lineNumber, - forInOfStmt.columnNumber - ) + val nextMemberNode = + createFieldIdentifierNode("next", forInOfStmt.lineNumber, forInOfStmt.columnNumber) - val testCallArgs = List(testNode) - val testCallAst = callAst(testCallNode, testCallArgs) + val nextReceiverNode = + createFieldAccessCallAst( + nextBaseNode, + nextMemberNode, + forInOfStmt.lineNumber, + forInOfStmt.columnNumber + ) - val whileLoopAst = - Ast(whileLoopNode).withChild(testCallAst).withConditionEdge(whileLoopNode, testCallNode) + val thisNextNode = identifierNode(forInOfStmt, iteratorName) - // while loop variable assignment: - val whileLoopVariableNode = astForNode(idNodeInfo.json) + val rhsArgs = List(Ast(thisNextNode)) + val rhsAst = callAst(rhsNode, rhsArgs, receiver = Option(nextReceiverNode)) - val baseNode = identifierNode(forInOfStmt, resultName) + val doneBaseArgs = List(Ast(lhsNode), rhsAst) + val doneBaseAst = callAst(doneBaseNode, doneBaseArgs) + Ast.storeInDiffGraph(doneBaseAst, diffGraph) - val memberNode = - createFieldIdentifierNode("value", forInOfStmt.lineNumber, forInOfStmt.columnNumber) + val doneMemberNode = + createFieldIdentifierNode("done", forInOfStmt.lineNumber, forInOfStmt.columnNumber) - val accessAst = createFieldAccessCallAst( - baseNode, - memberNode, + val testNode = + createFieldAccessCallAst( + doneBaseNode, + doneMemberNode, forInOfStmt.lineNumber, forInOfStmt.columnNumber ) - val loopVariableAssignmentNode = callNode( + val testCallArgs = List(testNode) + val testCallAst = callAst(testCallNode, testCallArgs) + + val whileLoopAst = + Ast(whileLoopNode).withChild(testCallAst).withConditionEdge(whileLoopNode, testCallNode) + + // while loop variable assignment: + val loopVariableAssignmentAsts = loopVariableNames.map { loopVariableName => + val whileLoopVariableNode = identifierNode(forInOfStmt, loopVariableName) + val baseNode = identifierNode(forInOfStmt, resultName) + val memberNode = + createFieldIdentifierNode("value", forInOfStmt.lineNumber, forInOfStmt.columnNumber) + val accessAst = createFieldAccessCallAst( + baseNode, + memberNode, + forInOfStmt.lineNumber, + forInOfStmt.columnNumber + ) + val variableMemberNode = + createFieldIdentifierNode( + loopVariableName, + forInOfStmt.lineNumber, + forInOfStmt.columnNumber + ) + val variableAccessAst = + createFieldAccessCallAst( + accessAst, + variableMemberNode, + forInOfStmt.lineNumber, + forInOfStmt.columnNumber + ) + val loopVariableAssignmentNode = callNode( + forInOfStmt, + s"$loopVariableName = $resultName.value.$loopVariableName", + Operators.assignment, + DispatchTypes.STATIC_DISPATCH + ) + val loopVariableAssignmentArgs = List(Ast(whileLoopVariableNode), variableAccessAst) + callAst(loopVariableAssignmentNode, loopVariableAssignmentArgs) + } + + val whileLoopBlockNode = createBlockNode(forInOfStmt) + scope.pushNewBlockScope(whileLoopBlockNode) + localAstParentStack.push(whileLoopBlockNode) + + // while loop block: + val bodyAst = astForNodeWithFunctionReference(forInOfStmt.json("body")) + + val whileLoopBlockChildren = loopVariableAssignmentAsts :+ bodyAst + setArgumentIndices(whileLoopBlockChildren) + val whileLoopBlockAst = blockAst(whileLoopBlockNode, whileLoopBlockChildren) + + scope.popScope() + localAstParentStack.pop() + + // end surrounding block: + scope.popScope() + localAstParentStack.pop() + + val blockNodeChildren = + List(iteratorAssignmentAst, Ast(resultNode)) ++ loopVariableNodes.map( + Ast(_) + ) :+ whileLoopAst.withChild( + whileLoopBlockAst + ) + setArgumentIndices(blockNodeChildren) + blockAst(blockNode, blockNodeChildren) + end astForInOfStatementWithObject + + /** De-sugaring from: + * + * for(var [a, b, c] of arr) { body } + * + * to: + * + * { var _iterator = .iterator(arr); var _result; var a; var b; var c; while (!(_result + * \= _iterator.next()).done) { a = _result.value[0]; b = _result.value[1]; c = _result.value[2]; + * body } } + */ + private def astForInOfStatementWithArray( + forInOfStmt: BabelNodeInfo, + idNodeInfo: BabelNodeInfo + ): Ast = + // surrounding block: + val blockNode = createBlockNode(forInOfStmt) + scope.pushNewBlockScope(blockNode) + localAstParentStack.push(blockNode) + + val collection = forInOfStmt.json("right") + val collectionName = code(collection) + + // _iterator assignment: + val iteratorName = generateUnusedVariableName(usedVariableNames, "_iterator") + val iteratorLocalNode = newLocalNode(iteratorName, Defines.Any).order(0) + val iteratorNode = identifierNode(forInOfStmt, iteratorName) + diffGraph.addEdge(localAstParentStack.head, iteratorLocalNode, EdgeTypes.AST) + scope.addVariableReference(iteratorName, iteratorNode) + + val iteratorCall = callNode( + forInOfStmt, + s".iterator($collectionName)", + ".iterator", + DispatchTypes.STATIC_DISPATCH + ) + + val objectKeysCallArgs = List(astForNodeWithFunctionReference(collection)) + val objectKeysCallAst = callAst(iteratorCall, objectKeysCallArgs) + + val iteratorAssignmentNode = + callNode( forInOfStmt, - s"${idNodeInfo.code} = $resultName.value", + s"$iteratorName = .iterator($collectionName)", Operators.assignment, DispatchTypes.STATIC_DISPATCH ) - val loopVariableAssignmentArgs = List(whileLoopVariableNode, accessAst) - val loopVariableAssignmentAst = - callAst(loopVariableAssignmentNode, loopVariableAssignmentArgs) - - val whileLoopBlockNode = createBlockNode(forInOfStmt) - scope.pushNewBlockScope(whileLoopBlockNode) - localAstParentStack.push(whileLoopBlockNode) - - // while loop block: - val bodyAst = astForNodeWithFunctionReference(forInOfStmt.json("body")) - - val whileLoopBlockChildren = List(loopVariableAssignmentAst, bodyAst) - setArgumentIndices(whileLoopBlockChildren) - val whileLoopBlockAst = blockAst(whileLoopBlockNode, whileLoopBlockChildren) - - scope.popScope() - localAstParentStack.pop() - - // end surrounding block: - scope.popScope() - localAstParentStack.pop() - - val blockChildren = - List(iteratorAssignmentAst, Ast(resultNode), whileLoopAst.withChild(whileLoopBlockAst)) - setArgumentIndices(blockChildren) - blockAst(blockNode, blockChildren) - end astForInOfStatementWithExpression - - /** De-sugaring from: - * - * for(var {a, b, c} of obj) { body } - * - * to: - * - * { var _iterator = .iterator(obj); var _result; var a; var b; var c; while - * (!(_result = _iterator.next()).done) { a = _result.value.a; b = _result.value.b; c = - * _result.value.c; body } } - */ - private def astForInOfStatementWithObject( - forInOfStmt: BabelNodeInfo, - idNodeInfo: BabelNodeInfo - ): Ast = - // surrounding block: - val blockNode = createBlockNode(forInOfStmt) - scope.pushNewBlockScope(blockNode) - localAstParentStack.push(blockNode) - - val collection = forInOfStmt.json("right") - val collectionName = code(collection) - - // _iterator assignment: - val iteratorName = generateUnusedVariableName(usedVariableNames, "_iterator") - val iteratorLocalNode = newLocalNode(iteratorName, Defines.Any).order(0) - val iteratorNode = identifierNode(forInOfStmt, iteratorName) - diffGraph.addEdge(localAstParentStack.head, iteratorLocalNode, EdgeTypes.AST) - scope.addVariableReference(iteratorName, iteratorNode) - // TODO: add operator to schema - val iteratorCall = - callNode( - forInOfStmt, - s".iterator($collectionName)", - ".iterator", - DispatchTypes.STATIC_DISPATCH - ) - - val objectKeysCallArgs = List(astForNodeWithFunctionReference(collection)) - val objectKeysCallAst = callAst(iteratorCall, objectKeysCallArgs) - - val iteratorAssignmentNode = - callNode( - forInOfStmt, - s"$iteratorName = .iterator($collectionName)", - Operators.assignment, - DispatchTypes.STATIC_DISPATCH - ) - - val iteratorAssignmentArgs = List(Ast(iteratorNode), objectKeysCallAst) - val iteratorAssignmentAst = callAst(iteratorAssignmentNode, iteratorAssignmentArgs) - - // _result: - val resultName = generateUnusedVariableName(usedVariableNames, "_result") - val resultLocalNode = newLocalNode(resultName, Defines.Any).order(0) - val resultNode = identifierNode(forInOfStmt, resultName) - diffGraph.addEdge(localAstParentStack.head, resultLocalNode, EdgeTypes.AST) - scope.addVariableReference(resultName, resultNode) - - // loop variable: - val loopVariableNames = idNodeInfo.json("properties").arr.toList.map(code) - - val loopVariableLocalNodes = loopVariableNames.map(newLocalNode(_, Defines.Any).order(0)) - val loopVariableNodes = loopVariableNames.map(identifierNode(forInOfStmt, _)) - loopVariableLocalNodes.foreach(diffGraph.addEdge( - localAstParentStack.head, - _, - EdgeTypes.AST - )) - loopVariableNames.zip(loopVariableNodes).foreach { - case (loopVariableName, loopVariableNode) => - scope.addVariableReference(loopVariableName, loopVariableNode) - } - - // while loop: - val whileLoopNode = createControlStructureNode(forInOfStmt, ControlStructureTypes.WHILE) - - // while loop test: - val testCallNode = - callNode( - forInOfStmt, - s"!($resultName = $iteratorName.next()).done", - Operators.not, - DispatchTypes.STATIC_DISPATCH - ) - - val doneBaseNode = callNode( + val iteratorAssignmentArgs = List(Ast(iteratorNode), objectKeysCallAst) + val iteratorAssignmentAst = callAst(iteratorAssignmentNode, iteratorAssignmentArgs) + + // _result: + val resultName = generateUnusedVariableName(usedVariableNames, "_result") + val resultLocalNode = newLocalNode(resultName, Defines.Any).order(0) + val resultNode = identifierNode(forInOfStmt, resultName) + diffGraph.addEdge(localAstParentStack.head, resultLocalNode, EdgeTypes.AST) + scope.addVariableReference(resultName, resultNode) + + // loop variable: + val loopVariableNames = idNodeInfo.json("elements").arr.toList.map(code) + + val loopVariableLocalNodes = loopVariableNames.map(newLocalNode(_, Defines.Any).order(0)) + val loopVariableNodes = loopVariableNames.map(identifierNode(forInOfStmt, _)) + loopVariableLocalNodes.foreach(diffGraph.addEdge( + localAstParentStack.head, + _, + EdgeTypes.AST + )) + loopVariableNames.zip(loopVariableNodes).foreach { + case (loopVariableName, loopVariableNode) => + scope.addVariableReference(loopVariableName, loopVariableNode) + } + + // while loop: + val whileLoopNode = createControlStructureNode(forInOfStmt, ControlStructureTypes.WHILE) + + // while loop test: + val testCallNode = + callNode( forInOfStmt, - s"($resultName = $iteratorName.next())", - Operators.assignment, + s"!($resultName = $iteratorName.next()).done", + Operators.not, DispatchTypes.STATIC_DISPATCH ) - val lhsNode = identifierNode(forInOfStmt, resultName) + val doneBaseNode = callNode( + forInOfStmt, + s"($resultName = $iteratorName.next())", + Operators.assignment, + DispatchTypes.STATIC_DISPATCH + ) - val rhsNode = - callNode(forInOfStmt, s"$iteratorName.next()", "next", DispatchTypes.DYNAMIC_DISPATCH) + val lhsNode = identifierNode(forInOfStmt, resultName) - val nextBaseNode = identifierNode(forInOfStmt, iteratorName) + val rhsNode = + callNode(forInOfStmt, s"$iteratorName.next()", "next", DispatchTypes.DYNAMIC_DISPATCH) - val nextMemberNode = - createFieldIdentifierNode("next", forInOfStmt.lineNumber, forInOfStmt.columnNumber) + val nextBaseNode = identifierNode(forInOfStmt, iteratorName) - val nextReceiverNode = - createFieldAccessCallAst( - nextBaseNode, - nextMemberNode, - forInOfStmt.lineNumber, - forInOfStmt.columnNumber - ) + val nextMemberNode = + createFieldIdentifierNode("next", forInOfStmt.lineNumber, forInOfStmt.columnNumber) - val thisNextNode = identifierNode(forInOfStmt, iteratorName) + val nextReceiverNode = + createFieldAccessCallAst( + nextBaseNode, + nextMemberNode, + forInOfStmt.lineNumber, + forInOfStmt.columnNumber + ) - val rhsArgs = List(Ast(thisNextNode)) - val rhsAst = callAst(rhsNode, rhsArgs, receiver = Option(nextReceiverNode)) + val thisNextNode = identifierNode(forInOfStmt, iteratorName) - val doneBaseArgs = List(Ast(lhsNode), rhsAst) - val doneBaseAst = callAst(doneBaseNode, doneBaseArgs) - Ast.storeInDiffGraph(doneBaseAst, diffGraph) + val rhsArgs = List(Ast(thisNextNode)) + val rhsAst = callAst(rhsNode, rhsArgs, receiver = Option(nextReceiverNode)) - val doneMemberNode = - createFieldIdentifierNode("done", forInOfStmt.lineNumber, forInOfStmt.columnNumber) + val doneBaseArgs = List(Ast(lhsNode), rhsAst) + val doneBaseAst = callAst(doneBaseNode, doneBaseArgs) + Ast.storeInDiffGraph(doneBaseAst, diffGraph) - val testNode = - createFieldAccessCallAst( - doneBaseNode, - doneMemberNode, - forInOfStmt.lineNumber, - forInOfStmt.columnNumber - ) + val doneMemberNode = + createFieldIdentifierNode("done", forInOfStmt.lineNumber, forInOfStmt.columnNumber) + + val testNode = + createFieldAccessCallAst( + doneBaseNode, + doneMemberNode, + forInOfStmt.lineNumber, + forInOfStmt.columnNumber + ) - val testCallArgs = List(testNode) - val testCallAst = callAst(testCallNode, testCallArgs) + val testCallArgs = List(testNode) + val testCallAst = callAst(testCallNode, testCallArgs) - val whileLoopAst = - Ast(whileLoopNode).withChild(testCallAst).withConditionEdge(whileLoopNode, testCallNode) + val whileLoopAst = + Ast(whileLoopNode).withChild(testCallAst).withConditionEdge(whileLoopNode, testCallNode) - // while loop variable assignment: - val loopVariableAssignmentAsts = loopVariableNames.map { loopVariableName => + // while loop variable assignment: + val loopVariableAssignmentAsts = + loopVariableNames.zipWithIndex.map { case (loopVariableName, index) => val whileLoopVariableNode = identifierNode(forInOfStmt, loopVariableName) val baseNode = identifierNode(forInOfStmt, resultName) - val memberNode = - createFieldIdentifierNode("value", forInOfStmt.lineNumber, forInOfStmt.columnNumber) + val memberNode = createFieldIdentifierNode( + "value", + forInOfStmt.lineNumber, + forInOfStmt.columnNumber + ) val accessAst = createFieldAccessCallAst( baseNode, memberNode, forInOfStmt.lineNumber, forInOfStmt.columnNumber ) - val variableMemberNode = - createFieldIdentifierNode( - loopVariableName, - forInOfStmt.lineNumber, - forInOfStmt.columnNumber - ) + val variableMemberNode = literalNode( + forInOfStmt, + index.toString, + dynamicTypeOption = Some(Defines.Number) + ) val variableAccessAst = - createFieldAccessCallAst( + createIndexAccessCallAst( accessAst, - variableMemberNode, + Ast(variableMemberNode), forInOfStmt.lineNumber, forInOfStmt.columnNumber ) val loopVariableAssignmentNode = callNode( forInOfStmt, - s"$loopVariableName = $resultName.value.$loopVariableName", + s"$loopVariableName = $resultName.value[$index]", Operators.assignment, DispatchTypes.STATIC_DISPATCH ) @@ -816,265 +1017,64 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode): callAst(loopVariableAssignmentNode, loopVariableAssignmentArgs) } - val whileLoopBlockNode = createBlockNode(forInOfStmt) - scope.pushNewBlockScope(whileLoopBlockNode) - localAstParentStack.push(whileLoopBlockNode) + val whileLoopBlockNode = createBlockNode(forInOfStmt) + scope.pushNewBlockScope(whileLoopBlockNode) + localAstParentStack.push(whileLoopBlockNode) - // while loop block: - val bodyAst = astForNodeWithFunctionReference(forInOfStmt.json("body")) + // while loop block: + val bodyAst = astForNodeWithFunctionReference(forInOfStmt.json("body")) - val whileLoopBlockChildren = loopVariableAssignmentAsts :+ bodyAst - setArgumentIndices(whileLoopBlockChildren) - val whileLoopBlockAst = blockAst(whileLoopBlockNode, whileLoopBlockChildren) + val whileLoopBlockChildren = loopVariableAssignmentAsts :+ bodyAst + setArgumentIndices(whileLoopBlockChildren) + val whileLoopBlockAst = blockAst(whileLoopBlockNode, whileLoopBlockChildren) - scope.popScope() - localAstParentStack.pop() + scope.popScope() + localAstParentStack.pop() - // end surrounding block: - scope.popScope() - localAstParentStack.pop() + // end surrounding block: + scope.popScope() + localAstParentStack.pop() - val blockNodeChildren = - List(iteratorAssignmentAst, Ast(resultNode)) ++ loopVariableNodes.map( - Ast(_) - ) :+ whileLoopAst.withChild( - whileLoopBlockAst - ) - setArgumentIndices(blockNodeChildren) - blockAst(blockNode, blockNodeChildren) - end astForInOfStatementWithObject - - /** De-sugaring from: - * - * for(var [a, b, c] of arr) { body } - * - * to: - * - * { var _iterator = .iterator(arr); var _result; var a; var b; var c; while - * (!(_result = _iterator.next()).done) { a = _result.value[0]; b = _result.value[1]; c = - * _result.value[2]; body } } - */ - private def astForInOfStatementWithArray( - forInOfStmt: BabelNodeInfo, - idNodeInfo: BabelNodeInfo - ): Ast = - // surrounding block: - val blockNode = createBlockNode(forInOfStmt) - scope.pushNewBlockScope(blockNode) - localAstParentStack.push(blockNode) - - val collection = forInOfStmt.json("right") - val collectionName = code(collection) - - // _iterator assignment: - val iteratorName = generateUnusedVariableName(usedVariableNames, "_iterator") - val iteratorLocalNode = newLocalNode(iteratorName, Defines.Any).order(0) - val iteratorNode = identifierNode(forInOfStmt, iteratorName) - diffGraph.addEdge(localAstParentStack.head, iteratorLocalNode, EdgeTypes.AST) - scope.addVariableReference(iteratorName, iteratorNode) - - val iteratorCall = callNode( - forInOfStmt, - s".iterator($collectionName)", - ".iterator", - DispatchTypes.STATIC_DISPATCH - ) - - val objectKeysCallArgs = List(astForNodeWithFunctionReference(collection)) - val objectKeysCallAst = callAst(iteratorCall, objectKeysCallArgs) - - val iteratorAssignmentNode = - callNode( - forInOfStmt, - s"$iteratorName = .iterator($collectionName)", - Operators.assignment, - DispatchTypes.STATIC_DISPATCH - ) - - val iteratorAssignmentArgs = List(Ast(iteratorNode), objectKeysCallAst) - val iteratorAssignmentAst = callAst(iteratorAssignmentNode, iteratorAssignmentArgs) - - // _result: - val resultName = generateUnusedVariableName(usedVariableNames, "_result") - val resultLocalNode = newLocalNode(resultName, Defines.Any).order(0) - val resultNode = identifierNode(forInOfStmt, resultName) - diffGraph.addEdge(localAstParentStack.head, resultLocalNode, EdgeTypes.AST) - scope.addVariableReference(resultName, resultNode) - - // loop variable: - val loopVariableNames = idNodeInfo.json("elements").arr.toList.map(code) - - val loopVariableLocalNodes = loopVariableNames.map(newLocalNode(_, Defines.Any).order(0)) - val loopVariableNodes = loopVariableNames.map(identifierNode(forInOfStmt, _)) - loopVariableLocalNodes.foreach(diffGraph.addEdge( - localAstParentStack.head, - _, - EdgeTypes.AST - )) - loopVariableNames.zip(loopVariableNodes).foreach { - case (loopVariableName, loopVariableNode) => - scope.addVariableReference(loopVariableName, loopVariableNode) - } - - // while loop: - val whileLoopNode = createControlStructureNode(forInOfStmt, ControlStructureTypes.WHILE) - - // while loop test: - val testCallNode = - callNode( - forInOfStmt, - s"!($resultName = $iteratorName.next()).done", - Operators.not, - DispatchTypes.STATIC_DISPATCH - ) - - val doneBaseNode = callNode( - forInOfStmt, - s"($resultName = $iteratorName.next())", - Operators.assignment, - DispatchTypes.STATIC_DISPATCH + val blockNodeChildren = + List(iteratorAssignmentAst, Ast(resultNode)) ++ loopVariableNodes.map( + Ast(_) + ) :+ whileLoopAst.withChild( + whileLoopBlockAst ) - - val lhsNode = identifierNode(forInOfStmt, resultName) - - val rhsNode = - callNode(forInOfStmt, s"$iteratorName.next()", "next", DispatchTypes.DYNAMIC_DISPATCH) - - val nextBaseNode = identifierNode(forInOfStmt, iteratorName) - - val nextMemberNode = - createFieldIdentifierNode("next", forInOfStmt.lineNumber, forInOfStmt.columnNumber) - - val nextReceiverNode = - createFieldAccessCallAst( - nextBaseNode, - nextMemberNode, - forInOfStmt.lineNumber, - forInOfStmt.columnNumber - ) - - val thisNextNode = identifierNode(forInOfStmt, iteratorName) - - val rhsArgs = List(Ast(thisNextNode)) - val rhsAst = callAst(rhsNode, rhsArgs, receiver = Option(nextReceiverNode)) - - val doneBaseArgs = List(Ast(lhsNode), rhsAst) - val doneBaseAst = callAst(doneBaseNode, doneBaseArgs) - Ast.storeInDiffGraph(doneBaseAst, diffGraph) - - val doneMemberNode = - createFieldIdentifierNode("done", forInOfStmt.lineNumber, forInOfStmt.columnNumber) - - val testNode = - createFieldAccessCallAst( - doneBaseNode, - doneMemberNode, - forInOfStmt.lineNumber, - forInOfStmt.columnNumber - ) - - val testCallArgs = List(testNode) - val testCallAst = callAst(testCallNode, testCallArgs) - - val whileLoopAst = - Ast(whileLoopNode).withChild(testCallAst).withConditionEdge(whileLoopNode, testCallNode) - - // while loop variable assignment: - val loopVariableAssignmentAsts = - loopVariableNames.zipWithIndex.map { case (loopVariableName, index) => - val whileLoopVariableNode = identifierNode(forInOfStmt, loopVariableName) - val baseNode = identifierNode(forInOfStmt, resultName) - val memberNode = createFieldIdentifierNode( - "value", - forInOfStmt.lineNumber, - forInOfStmt.columnNumber - ) - val accessAst = createFieldAccessCallAst( - baseNode, - memberNode, - forInOfStmt.lineNumber, - forInOfStmt.columnNumber - ) - val variableMemberNode = literalNode( - forInOfStmt, - index.toString, - dynamicTypeOption = Some(Defines.Number) - ) - val variableAccessAst = - createIndexAccessCallAst( - accessAst, - Ast(variableMemberNode), - forInOfStmt.lineNumber, - forInOfStmt.columnNumber - ) - val loopVariableAssignmentNode = callNode( - forInOfStmt, - s"$loopVariableName = $resultName.value[$index]", - Operators.assignment, - DispatchTypes.STATIC_DISPATCH - ) - val loopVariableAssignmentArgs = List(Ast(whileLoopVariableNode), variableAccessAst) - callAst(loopVariableAssignmentNode, loopVariableAssignmentArgs) - } - - val whileLoopBlockNode = createBlockNode(forInOfStmt) - scope.pushNewBlockScope(whileLoopBlockNode) - localAstParentStack.push(whileLoopBlockNode) - - // while loop block: - val bodyAst = astForNodeWithFunctionReference(forInOfStmt.json("body")) - - val whileLoopBlockChildren = loopVariableAssignmentAsts :+ bodyAst - setArgumentIndices(whileLoopBlockChildren) - val whileLoopBlockAst = blockAst(whileLoopBlockNode, whileLoopBlockChildren) - - scope.popScope() - localAstParentStack.pop() - - // end surrounding block: - scope.popScope() - localAstParentStack.pop() - - val blockNodeChildren = - List(iteratorAssignmentAst, Ast(resultNode)) ++ loopVariableNodes.map( - Ast(_) - ) :+ whileLoopAst.withChild( - whileLoopBlockAst - ) - setArgumentIndices(blockNodeChildren) - blockAst(blockNode, blockNodeChildren) - end astForInOfStatementWithArray - - private def extractLoopVariableNodeInfo(nodeInfo: BabelNodeInfo): Option[BabelNodeInfo] = - nodeInfo.node match - case AssignmentPattern => - Option(createBabelNodeInfo(nodeInfo.json("left"))) - case VariableDeclaration => - val varDeclNodeInfo = createBabelNodeInfo(nodeInfo.json("declarations").arr.head) - if varDeclNodeInfo.node == VariableDeclarator then - Option(createBabelNodeInfo(varDeclNodeInfo.json("id"))) - else None - case _ => None - - protected def astForInOfStatement(forInOfStmt: BabelNodeInfo): Ast = - val loopVariableNodeInfo = createBabelNodeInfo(forInOfStmt.json("left")) - // check iteration loop variable type: - loopVariableNodeInfo.node match - case VariableDeclaration | AssignmentPattern => - val idNodeInfo = extractLoopVariableNodeInfo(loopVariableNodeInfo) - idNodeInfo.map(_.node) match - case Some(ObjectPattern) => - astForInOfStatementWithObject(forInOfStmt, idNodeInfo.get) - case Some(ArrayPattern) => - astForInOfStatementWithArray(forInOfStmt, idNodeInfo.get) - case Some(Identifier) => - astForInOfStatementWithIdentifier(forInOfStmt, idNodeInfo.get) - case _ => notHandledYet(forInOfStmt) - case ObjectPattern => astForInOfStatementWithObject(forInOfStmt, loopVariableNodeInfo) - case ArrayPattern => astForInOfStatementWithArray(forInOfStmt, loopVariableNodeInfo) - case Identifier => astForInOfStatementWithIdentifier(forInOfStmt, loopVariableNodeInfo) - case _: Expression => - astForInOfStatementWithExpression(forInOfStmt, loopVariableNodeInfo) + setArgumentIndices(blockNodeChildren) + blockAst(blockNode, blockNodeChildren) + end astForInOfStatementWithArray + + private def extractLoopVariableNodeInfo(nodeInfo: BabelNodeInfo): Option[BabelNodeInfo] = + nodeInfo.node match + case AssignmentPattern => + Option(createBabelNodeInfo(nodeInfo.json("left"))) + case VariableDeclaration => + val varDeclNodeInfo = createBabelNodeInfo(nodeInfo.json("declarations").arr.head) + if varDeclNodeInfo.node == VariableDeclarator then + Option(createBabelNodeInfo(varDeclNodeInfo.json("id"))) + else None + case _ => None + + protected def astForInOfStatement(forInOfStmt: BabelNodeInfo): Ast = + val loopVariableNodeInfo = createBabelNodeInfo(forInOfStmt.json("left")) + // check iteration loop variable type: + loopVariableNodeInfo.node match + case VariableDeclaration | AssignmentPattern => + val idNodeInfo = extractLoopVariableNodeInfo(loopVariableNodeInfo) + idNodeInfo.map(_.node) match + case Some(ObjectPattern) => + astForInOfStatementWithObject(forInOfStmt, idNodeInfo.get) + case Some(ArrayPattern) => + astForInOfStatementWithArray(forInOfStmt, idNodeInfo.get) + case Some(Identifier) => + astForInOfStatementWithIdentifier(forInOfStmt, idNodeInfo.get) case _ => notHandledYet(forInOfStmt) - end astForInOfStatement + case ObjectPattern => astForInOfStatementWithObject(forInOfStmt, loopVariableNodeInfo) + case ArrayPattern => astForInOfStatementWithArray(forInOfStmt, loopVariableNodeInfo) + case Identifier => astForInOfStatementWithIdentifier(forInOfStmt, loopVariableNodeInfo) + case _: Expression => + astForInOfStatementWithExpression(forInOfStmt, loopVariableNodeInfo) + case _ => notHandledYet(forInOfStmt) + end astForInOfStatement end AstForStatementsCreator diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForTemplateDomCreator.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForTemplateDomCreator.scala index 320a811f..4532c843 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForTemplateDomCreator.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForTemplateDomCreator.scala @@ -6,114 +6,114 @@ import io.appthreat.x2cpg.{Ast, ValidationMode} import ujson.Obj trait AstForTemplateDomCreator(implicit withSchemaValidation: ValidationMode): - this: AstCreator => + this: AstCreator => - protected def astForJsxElement(jsxElem: BabelNodeInfo): Ast = - val domNode = createTemplateDomNode( - jsxElem.node.toString, - jsxElem.code, - jsxElem.lineNumber, - jsxElem.columnNumber - ) - val openingAst = astForNodeWithFunctionReference(jsxElem.json("openingElement")) - val childrenAsts = astForNodes(jsxElem.json("children").arr.toList) - val closingAst = - safeObj(jsxElem.json, "closingElement") - .map(e => astForNodeWithFunctionReference(Obj(e))) - .getOrElse(Ast()) - val allChildrenAsts = openingAst +: childrenAsts :+ closingAst - setArgumentIndices(allChildrenAsts) - Ast(domNode).withChildren(allChildrenAsts) - - protected def astForJsxFragment(jsxFragment: BabelNodeInfo): Ast = - val domNode = createTemplateDomNode( - jsxFragment.node.toString, - jsxFragment.code, - jsxFragment.lineNumber, - jsxFragment.columnNumber - ) - val childrenAsts = astForNodes(jsxFragment.json("children").arr.toList) - setArgumentIndices(childrenAsts) - Ast(domNode).withChildren(childrenAsts) - - protected def astForJsxAttribute(jsxAttr: BabelNodeInfo): Ast = - // A colon in front of a JSXAttribute cant be parsed by Babel. - // Hence, we strip it away with astgen and restore it here. - // parserResult.fileContent contains the unmodified Vue.js source code for the current file. - // We look at the previous character there and re-add the colon if needed. - val colon = pos(jsxAttr.json) - .collect { - case position - if position > 0 && parserResult.fileContent.substring( - position - 1, - position - ) == ":" => ":" - } - .getOrElse("") - val domNode = - createTemplateDomNode( - jsxAttr.node.toString, - s"$colon${jsxAttr.code}", - jsxAttr.lineNumber, - jsxAttr.columnNumber.map(_ - colon.length) - ) - val valueAst = safeObj(jsxAttr.json, "value") + protected def astForJsxElement(jsxElem: BabelNodeInfo): Ast = + val domNode = createTemplateDomNode( + jsxElem.node.toString, + jsxElem.code, + jsxElem.lineNumber, + jsxElem.columnNumber + ) + val openingAst = astForNodeWithFunctionReference(jsxElem.json("openingElement")) + val childrenAsts = astForNodes(jsxElem.json("children").arr.toList) + val closingAst = + safeObj(jsxElem.json, "closingElement") .map(e => astForNodeWithFunctionReference(Obj(e))) .getOrElse(Ast()) - setArgumentIndices(List(valueAst)) - Ast(domNode).withChild(valueAst) - end astForJsxAttribute + val allChildrenAsts = openingAst +: childrenAsts :+ closingAst + setArgumentIndices(allChildrenAsts) + Ast(domNode).withChildren(allChildrenAsts) - protected def astForJsxOpeningElement(jsxOpeningElem: BabelNodeInfo): Ast = - val domNode = createTemplateDomNode( - jsxOpeningElem.node.toString, - jsxOpeningElem.code, - jsxOpeningElem.lineNumber, - jsxOpeningElem.columnNumber - ) - val childrenAsts = astForNodes(jsxOpeningElem.json("attributes").arr.toList) - setArgumentIndices(childrenAsts) - Ast(domNode).withChildren(childrenAsts) + protected def astForJsxFragment(jsxFragment: BabelNodeInfo): Ast = + val domNode = createTemplateDomNode( + jsxFragment.node.toString, + jsxFragment.code, + jsxFragment.lineNumber, + jsxFragment.columnNumber + ) + val childrenAsts = astForNodes(jsxFragment.json("children").arr.toList) + setArgumentIndices(childrenAsts) + Ast(domNode).withChildren(childrenAsts) - protected def astForJsxClosingElement(jsxClosingElem: BabelNodeInfo): Ast = - val domNode = createTemplateDomNode( - jsxClosingElem.node.toString, - jsxClosingElem.code, - jsxClosingElem.lineNumber, - jsxClosingElem.columnNumber + protected def astForJsxAttribute(jsxAttr: BabelNodeInfo): Ast = + // A colon in front of a JSXAttribute cant be parsed by Babel. + // Hence, we strip it away with astgen and restore it here. + // parserResult.fileContent contains the unmodified Vue.js source code for the current file. + // We look at the previous character there and re-add the colon if needed. + val colon = pos(jsxAttr.json) + .collect { + case position + if position > 0 && parserResult.fileContent.substring( + position - 1, + position + ) == ":" => ":" + } + .getOrElse("") + val domNode = + createTemplateDomNode( + jsxAttr.node.toString, + s"$colon${jsxAttr.code}", + jsxAttr.lineNumber, + jsxAttr.columnNumber.map(_ - colon.length) ) - Ast(domNode) + val valueAst = safeObj(jsxAttr.json, "value") + .map(e => astForNodeWithFunctionReference(Obj(e))) + .getOrElse(Ast()) + setArgumentIndices(List(valueAst)) + Ast(domNode).withChild(valueAst) + end astForJsxAttribute - protected def astForJsxText(jsxText: BabelNodeInfo): Ast = - Ast(createTemplateDomNode( - jsxText.node.toString, - jsxText.code, - jsxText.lineNumber, - jsxText.columnNumber - )) + protected def astForJsxOpeningElement(jsxOpeningElem: BabelNodeInfo): Ast = + val domNode = createTemplateDomNode( + jsxOpeningElem.node.toString, + jsxOpeningElem.code, + jsxOpeningElem.lineNumber, + jsxOpeningElem.columnNumber + ) + val childrenAsts = astForNodes(jsxOpeningElem.json("attributes").arr.toList) + setArgumentIndices(childrenAsts) + Ast(domNode).withChildren(childrenAsts) - protected def astForJsxExprContainer(jsxExprContainer: BabelNodeInfo): Ast = - val domNode = createTemplateDomNode( - jsxExprContainer.node.toString, - jsxExprContainer.code, - jsxExprContainer.lineNumber, - jsxExprContainer.columnNumber - ) - val nodeInfo = createBabelNodeInfo(jsxExprContainer.json("expression")) - val exprAst = nodeInfo.node match - case JSXEmptyExpression => Ast() - case _ => astForNodeWithFunctionReference(nodeInfo.json) - setArgumentIndices(List(exprAst)) - Ast(domNode).withChild(exprAst) + protected def astForJsxClosingElement(jsxClosingElem: BabelNodeInfo): Ast = + val domNode = createTemplateDomNode( + jsxClosingElem.node.toString, + jsxClosingElem.code, + jsxClosingElem.lineNumber, + jsxClosingElem.columnNumber + ) + Ast(domNode) - protected def astForJsxSpreadAttribute(jsxSpreadAttr: BabelNodeInfo): Ast = - val domNode = createTemplateDomNode( - jsxSpreadAttr.node.toString, - jsxSpreadAttr.code, - jsxSpreadAttr.lineNumber, - jsxSpreadAttr.columnNumber - ) - val argAst = astForNodeWithFunctionReference(jsxSpreadAttr.json("argument")) - setArgumentIndices(List(argAst)) - Ast(domNode).withChild(argAst) + protected def astForJsxText(jsxText: BabelNodeInfo): Ast = + Ast(createTemplateDomNode( + jsxText.node.toString, + jsxText.code, + jsxText.lineNumber, + jsxText.columnNumber + )) + + protected def astForJsxExprContainer(jsxExprContainer: BabelNodeInfo): Ast = + val domNode = createTemplateDomNode( + jsxExprContainer.node.toString, + jsxExprContainer.code, + jsxExprContainer.lineNumber, + jsxExprContainer.columnNumber + ) + val nodeInfo = createBabelNodeInfo(jsxExprContainer.json("expression")) + val exprAst = nodeInfo.node match + case JSXEmptyExpression => Ast() + case _ => astForNodeWithFunctionReference(nodeInfo.json) + setArgumentIndices(List(exprAst)) + Ast(domNode).withChild(exprAst) + + protected def astForJsxSpreadAttribute(jsxSpreadAttr: BabelNodeInfo): Ast = + val domNode = createTemplateDomNode( + jsxSpreadAttr.node.toString, + jsxSpreadAttr.code, + jsxSpreadAttr.lineNumber, + jsxSpreadAttr.columnNumber + ) + val argAst = astForNodeWithFunctionReference(jsxSpreadAttr.json("argument")) + setArgumentIndices(List(argAst)) + Ast(domNode).withChild(argAst) end AstForTemplateDomCreator diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForTypesCreator.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForTypesCreator.scala index 8847705e..d80c77f1 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForTypesCreator.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstForTypesCreator.scala @@ -14,94 +14,94 @@ import ujson.Value import scala.util.Try trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode): - this: AstCreator => - - protected def astForTypeAlias(alias: BabelNodeInfo): Ast = - val (aliasName, aliasFullName) = calcTypeNameAndFullName(alias) - val name = if hasKey(alias.json, "right") then - typeFor(createBabelNodeInfo(alias.json("right"))) - else - typeFor(createBabelNodeInfo(alias.json)) - registerType(aliasName, aliasFullName) - - val astParentType = methodAstParentStack.head.label - val astParentFullName = methodAstParentStack.head.properties("FULL_NAME").toString - - val aliasTypeDeclNode = - typeDeclNode( - alias, - aliasName, - aliasFullName, - parserResult.filename, - alias.code, - astParentType, - astParentFullName - ) - seenAliasTypes.add(aliasTypeDeclNode) - - val typeDeclNodeAst = - if !Defines.JsTypes.contains(name) && !seenAliasTypes.exists(_.name == name) then - val (typeName, typeFullName) = calcTypeNameAndFullName(alias, Option(name)) - val typeDeclNode_ = typeDeclNode( - alias, - typeName, - typeFullName, - parserResult.filename, - alias.code, - astParentType, - astParentFullName, - alias = Option(aliasFullName) - ) - registerType(typeName, typeFullName) - Ast(typeDeclNode_) - else - seenAliasTypes - .collectFirst { - case typeDecl if typeDecl.name == name => - Ast(typeDecl.aliasTypeFullName(aliasFullName)) - } - .getOrElse(Ast()) - - // adding all class methods / functions and uninitialized, non-static members - val membersAndInitializers = (alias.node match - case TSTypeLiteral => classMembersForTypeAlias(alias) - case ObjectPattern => Try(alias.json("properties").arr).toOption.toSeq.flatten - case _ => classMembersForTypeAlias(createBabelNodeInfo(alias.json("typeAnnotation"))) - ).filter(member => - isClassMethodOrUninitializedMemberOrObjectProperty(member) && !isStaticMember(member) + this: AstCreator => + + protected def astForTypeAlias(alias: BabelNodeInfo): Ast = + val (aliasName, aliasFullName) = calcTypeNameAndFullName(alias) + val name = if hasKey(alias.json, "right") then + typeFor(createBabelNodeInfo(alias.json("right"))) + else + typeFor(createBabelNodeInfo(alias.json)) + registerType(aliasName, aliasFullName) + + val astParentType = methodAstParentStack.head.label + val astParentFullName = methodAstParentStack.head.properties("FULL_NAME").toString + + val aliasTypeDeclNode = + typeDeclNode( + alias, + aliasName, + aliasFullName, + parserResult.filename, + alias.code, + astParentType, + astParentFullName ) - .map(m => astForClassMember(m, aliasTypeDeclNode)) - typeDeclNodeAst.root.foreach(diffGraph.addEdge(methodAstParentStack.head, _, EdgeTypes.AST)) - Ast(aliasTypeDeclNode).withChildren(membersAndInitializers) - end astForTypeAlias - - private def isConstructor(json: Value): Boolean = createBabelNodeInfo(json).node match - case TSConstructSignatureDeclaration => true - case _ => safeStr(json, "kind").contains("constructor") - - private def classMembers(clazz: BabelNodeInfo, withConstructor: Boolean = true): Seq[Value] = - val allMembers = Try(clazz.json("body")("body").arr).toOption.toSeq.flatten - val dynamicallyDeclaredMembers = - allMembers - .find(isConstructor) - .flatMap(c => Try(c("body")("body").arr).toOption) - .toSeq - .flatten - .filter(isInitializedMember) - if withConstructor then - allMembers ++ dynamicallyDeclaredMembers + seenAliasTypes.add(aliasTypeDeclNode) + + val typeDeclNodeAst = + if !Defines.JsTypes.contains(name) && !seenAliasTypes.exists(_.name == name) then + val (typeName, typeFullName) = calcTypeNameAndFullName(alias, Option(name)) + val typeDeclNode_ = typeDeclNode( + alias, + typeName, + typeFullName, + parserResult.filename, + alias.code, + astParentType, + astParentFullName, + alias = Option(aliasFullName) + ) + registerType(typeName, typeFullName) + Ast(typeDeclNode_) else - allMembers.filterNot(isConstructor) ++ dynamicallyDeclaredMembers - - private def classMembersForTypeAlias(alias: BabelNodeInfo): Seq[Value] = - Try(alias.json("members").arr).toOption.toSeq.flatten - - private def createFakeConstructor( - code: String, - forElem: BabelNodeInfo, - methodBlockContent: List[Ast] = List.empty - ): MethodAst = - val fakeConstructorCode = s"""{ + seenAliasTypes + .collectFirst { + case typeDecl if typeDecl.name == name => + Ast(typeDecl.aliasTypeFullName(aliasFullName)) + } + .getOrElse(Ast()) + + // adding all class methods / functions and uninitialized, non-static members + val membersAndInitializers = (alias.node match + case TSTypeLiteral => classMembersForTypeAlias(alias) + case ObjectPattern => Try(alias.json("properties").arr).toOption.toSeq.flatten + case _ => classMembersForTypeAlias(createBabelNodeInfo(alias.json("typeAnnotation"))) + ).filter(member => + isClassMethodOrUninitializedMemberOrObjectProperty(member) && !isStaticMember(member) + ) + .map(m => astForClassMember(m, aliasTypeDeclNode)) + typeDeclNodeAst.root.foreach(diffGraph.addEdge(methodAstParentStack.head, _, EdgeTypes.AST)) + Ast(aliasTypeDeclNode).withChildren(membersAndInitializers) + end astForTypeAlias + + private def isConstructor(json: Value): Boolean = createBabelNodeInfo(json).node match + case TSConstructSignatureDeclaration => true + case _ => safeStr(json, "kind").contains("constructor") + + private def classMembers(clazz: BabelNodeInfo, withConstructor: Boolean = true): Seq[Value] = + val allMembers = Try(clazz.json("body")("body").arr).toOption.toSeq.flatten + val dynamicallyDeclaredMembers = + allMembers + .find(isConstructor) + .flatMap(c => Try(c("body")("body").arr).toOption) + .toSeq + .flatten + .filter(isInitializedMember) + if withConstructor then + allMembers ++ dynamicallyDeclaredMembers + else + allMembers.filterNot(isConstructor) ++ dynamicallyDeclaredMembers + + private def classMembersForTypeAlias(alias: BabelNodeInfo): Seq[Value] = + Try(alias.json("members").arr).toOption.toSeq.flatten + + private def createFakeConstructor( + code: String, + forElem: BabelNodeInfo, + methodBlockContent: List[Ast] = List.empty + ): MethodAst = + val fakeConstructorCode = s"""{ | "type": "ClassMethod", | "key": { | "type": "Identifier", @@ -121,512 +121,512 @@ trait AstForTypesCreator(implicit withSchemaValidation: ValidationMode): | "body": [] | } |}""".stripMargin - val result = createMethodAstAndNode( - createBabelNodeInfo(ujson.read(fakeConstructorCode)), - methodBlockContent = methodBlockContent - ) - result.methodNode.code(code) - result - end createFakeConstructor - - private def findClassConstructor(clazz: BabelNodeInfo): Option[Value] = - classMembers(clazz).find(isConstructor) - - private def createClassConstructor( - classExpr: BabelNodeInfo, - constructorContent: List[Ast] - ): MethodAst = - findClassConstructor(classExpr) match - case Some(classConstructor) if hasKey(classConstructor, "body") => - val result = - createMethodAstAndNode( - createBabelNodeInfo(classConstructor), - methodBlockContent = constructorContent - ) - diffGraph.addEdge( - result.methodNode, - NewModifier().modifierType(ModifierTypes.CONSTRUCTOR), - EdgeTypes.AST - ) - result - case Some(classConstructor) => - val methodNode = - createMethodDefinitionNode( - createBabelNodeInfo(classConstructor), - methodBlockContent = constructorContent - ) - diffGraph.addEdge( - methodNode, - NewModifier().modifierType(ModifierTypes.CONSTRUCTOR), - EdgeTypes.AST - ) - MethodAst(Ast(methodNode), methodNode, Ast(methodNode)) - case _ => - val result = createFakeConstructor( - "constructor() {}", - classExpr, + val result = createMethodAstAndNode( + createBabelNodeInfo(ujson.read(fakeConstructorCode)), + methodBlockContent = methodBlockContent + ) + result.methodNode.code(code) + result + end createFakeConstructor + + private def findClassConstructor(clazz: BabelNodeInfo): Option[Value] = + classMembers(clazz).find(isConstructor) + + private def createClassConstructor( + classExpr: BabelNodeInfo, + constructorContent: List[Ast] + ): MethodAst = + findClassConstructor(classExpr) match + case Some(classConstructor) if hasKey(classConstructor, "body") => + val result = + createMethodAstAndNode( + createBabelNodeInfo(classConstructor), methodBlockContent = constructorContent ) - diffGraph.addEdge( - result.methodNode, - NewModifier().modifierType(ModifierTypes.CONSTRUCTOR), - EdgeTypes.AST - ) - result - - private def interfaceConstructor(typeName: String, tsInterface: BabelNodeInfo): NewMethod = - findClassConstructor(tsInterface) match - case Some(interfaceConstructor) => - createMethodDefinitionNode(createBabelNodeInfo(interfaceConstructor)) - case _ => createFakeConstructor(s"new: $typeName", tsInterface).methodNode - - private def astsForEnumMember(tsEnumMember: BabelNodeInfo): Seq[Ast] = - val name = code(tsEnumMember.json("id")) - val memberNode_ = memberNode(tsEnumMember, name, tsEnumMember.code, typeFor(tsEnumMember)) - addModifier(memberNode_, tsEnumMember.json) - - if hasKey(tsEnumMember.json, "initializer") then - val lhsAst = astForNode(tsEnumMember.json("id")) - val rhsAst = astForNodeWithFunctionReference(tsEnumMember.json("initializer")) - val callNode_ = - callNode( - tsEnumMember, - tsEnumMember.code, - Operators.assignment, - DispatchTypes.STATIC_DISPATCH - ) - val argAsts = List(lhsAst, rhsAst) - Seq(callAst(callNode_, argAsts), Ast(memberNode_)) - else - Seq(Ast(memberNode_)) - end astsForEnumMember - - private def astForClassMember(classElement: Value, typeDeclNode: NewTypeDecl): Ast = - val nodeInfo = createBabelNodeInfo(classElement) - val typeFullName = typeFor(nodeInfo) - val memberNode_ = nodeInfo.node match - case TSDeclareMethod | TSDeclareFunction => - val function = createMethodDefinitionNode(nodeInfo) - val bindingNode = newBindingNode("", "", "") - diffGraph.addEdge(typeDeclNode, bindingNode, EdgeTypes.BINDS) - diffGraph.addEdge(bindingNode, function, EdgeTypes.REF) - addModifier(function, nodeInfo.json) - memberNode( - nodeInfo, - function.name, - nodeInfo.code, - typeFullName, - Seq(function.fullName) - ) - case ClassMethod | ClassPrivateMethod => - val function = createMethodAstAndNode(nodeInfo).methodNode - val bindingNode = newBindingNode("", "", "") - diffGraph.addEdge(typeDeclNode, bindingNode, EdgeTypes.BINDS) - diffGraph.addEdge(bindingNode, function, EdgeTypes.REF) - addModifier(function, nodeInfo.json) - memberNode( - nodeInfo, - function.name, - nodeInfo.code, - typeFullName, - Seq(function.fullName) - ) - case ExpressionStatement if isInitializedMember(classElement) => - val memberNodeInfo = - createBabelNodeInfo(nodeInfo.json("expression")("left")("property")) - val name = memberNodeInfo.code - memberNode(nodeInfo, name, nodeInfo.code, typeFullName) - case TSPropertySignature | ObjectProperty if hasKey(nodeInfo.json("key"), "name") => - val memberNodeInfo = createBabelNodeInfo(nodeInfo.json("key")) - val name = memberNodeInfo.json("name").str - memberNode(nodeInfo, name, nodeInfo.code, typeFullName) - case _ => - val name = nodeInfo.node match - case ClassProperty => code(nodeInfo.json("key")) - case ClassPrivateProperty => code(nodeInfo.json("key")("id")) - // TODO: name field most likely needs adjustment for other Babel AST types - case _ => nodeInfo.code - memberNode(nodeInfo, name, nodeInfo.code, typeFullName) - - addModifier(memberNode_, classElement) - diffGraph.addEdge(typeDeclNode, memberNode_, EdgeTypes.AST) - astsForDecorators(nodeInfo).foreach { decoratorAst => - Ast.storeInDiffGraph(decoratorAst, diffGraph) - decoratorAst.root.foreach(diffGraph.addEdge(memberNode_, _, EdgeTypes.AST)) - } - - if hasKey(nodeInfo.json, "value") && !nodeInfo.json("value").isNull then - val lhsAst = astForNode(nodeInfo.json("key")) - val rhsAst = astForNodeWithFunctionReference(nodeInfo.json("value")) - val callNode_ = - callNode( - nodeInfo, - nodeInfo.code, - Operators.assignment, - DispatchTypes.STATIC_DISPATCH - ) - val argAsts = List(lhsAst, rhsAst) - callAst(callNode_, argAsts) - else - Ast() - end astForClassMember - - protected def astForEnum(tsEnum: BabelNodeInfo): Ast = - val (typeName, typeFullName) = calcTypeNameAndFullName(tsEnum) - registerType(typeName, typeFullName) - - val astParentType = methodAstParentStack.head.label - val astParentFullName = methodAstParentStack.head.properties("FULL_NAME").toString - - val typeDeclNode_ = typeDeclNode( - tsEnum, - typeName, - typeFullName, - parserResult.filename, - s"enum $typeName", - astParentType, - astParentFullName - ) - seenAliasTypes.add(typeDeclNode_) - - addModifier(typeDeclNode_, tsEnum.json) - - val typeRefNode_ = typeRefNode(tsEnum, s"enum $typeName", typeFullName) - - methodAstParentStack.push(typeDeclNode_) - dynamicInstanceTypeStack.push(typeFullName) - typeRefIdStack.push(typeRefNode_) - scope.pushNewMethodScope(typeFullName, typeName, typeDeclNode_, None) - - val memberAsts = tsEnum.json("members").arr.toList.flatMap(m => - astsForEnumMember(createBabelNodeInfo(m)) - ) - - methodAstParentStack.pop() - dynamicInstanceTypeStack.pop() - typeRefIdStack.pop() - scope.popScope() - - val (calls, member) = - memberAsts.partition(_.nodes.headOption.exists(_.isInstanceOf[NewCall])) - if calls.isEmpty then - Ast.storeInDiffGraph(Ast(typeDeclNode_).withChildren(member), diffGraph) - else - val init = - staticInitMethodAst( - calls, - s"$typeFullName:${io.appthreat.x2cpg.Defines.StaticInitMethodName}", - None, - Defines.Any - ) - Ast.storeInDiffGraph(Ast(typeDeclNode_).withChildren(member).withChild(init), diffGraph) - - diffGraph.addEdge(methodAstParentStack.head, typeDeclNode_, EdgeTypes.AST) - Ast(typeRefNode_) - end astForEnum - - private def isStaticMember(json: Value): Boolean = - val nodeInfo = createBabelNodeInfo(json).node - val isStatic = safeBool(json, "static").contains(true) - nodeInfo != ClassMethod && nodeInfo != ClassPrivateMethod && isStatic - - private def isInitializedMember(json: Value): Boolean = - val hasInitializedValue = hasKey(json, "value") && !json("value").isNull - val isAssignment = createBabelNodeInfo(json) match - case node if node.node == ExpressionStatement => - val exprNode = createBabelNodeInfo(node.json("expression")) - exprNode.node == AssignmentExpression && - createBabelNodeInfo(exprNode.json("left")).node == MemberExpression && - code(exprNode.json("left")("object")) == "this" - case _ => false - hasInitializedValue || isAssignment - - private def isStaticInitBlock(json: Value): Boolean = - createBabelNodeInfo(json).node == StaticBlock - - private def isClassMethodOrUninitializedMember(json: Value): Boolean = - val nodeInfo = createBabelNodeInfo(json).node - !isStaticInitBlock(json) && - (nodeInfo == ClassMethod || nodeInfo == ClassPrivateMethod || !isInitializedMember(json)) - - private def isClassMethodOrUninitializedMemberOrObjectProperty(json: Value): Boolean = - val nodeInfo = createBabelNodeInfo(json).node - !isStaticInitBlock(json) && - (nodeInfo == ObjectProperty || nodeInfo == ClassMethod || nodeInfo == ClassPrivateMethod || !isInitializedMember( - json - )) - - protected def astForClass( - clazz: BabelNodeInfo, - shouldCreateAssignmentCall: Boolean = false - ): Ast = - val (typeName, typeFullName) = calcTypeNameAndFullName(clazz) - registerType(typeName, typeFullName) - - val astParentType = methodAstParentStack.head.label - val astParentFullName = methodAstParentStack.head.properties("FULL_NAME").toString - - val superClass = Try(createBabelNodeInfo(clazz.json("superClass")).code).toOption.toSeq - val implements = Try( - clazz.json("implements").arr.map(createBabelNodeInfo(_).code) - ).toOption.toSeq.flatten - val mixins = - Try(clazz.json("mixins").arr.map(createBabelNodeInfo(_).code)).toOption.toSeq.flatten - - val typeDeclNode_ = typeDeclNode( - clazz, - typeName, - typeFullName, - parserResult.filename, - s"class $typeName", - astParentType, - astParentFullName, - inherits = superClass ++ implements ++ mixins - ) - seenAliasTypes.add(typeDeclNode_) - - addModifier(typeDeclNode_, clazz.json) - astsForDecorators(clazz).foreach { decoratorAst => - Ast.storeInDiffGraph(decoratorAst, diffGraph) - decoratorAst.root.foreach(diffGraph.addEdge(typeDeclNode_, _, EdgeTypes.AST)) - } - - diffGraph.addEdge(methodAstParentStack.head, typeDeclNode_, EdgeTypes.AST) - - val typeRefNode_ = typeRefNode(clazz, s"class $typeName", typeFullName) - - methodAstParentStack.push(typeDeclNode_) - dynamicInstanceTypeStack.push(typeFullName) - typeRefIdStack.push(typeRefNode_) - - scope.pushNewMethodScope(typeFullName, typeName, typeDeclNode_, None) - - val allClassMembers = classMembers(clazz, withConstructor = false).toList - - // adding all other members and retrieving their initialization calls - val memberInitCalls = allClassMembers - .filter(m => !isStaticMember(m) && isInitializedMember(m)) - .map(m => astForClassMember(m, typeDeclNode_)) - - val constructor = createClassConstructor(clazz, memberInitCalls) - val constructorNode = constructor.methodNode - - // adding all class methods / functions and uninitialized, non-static members - allClassMembers - .filter(member => isClassMethodOrUninitializedMember(member) && !isStaticMember(member)) - .map(m => astForClassMember(m, typeDeclNode_)) - - // adding all static members and retrieving their initialization calls - val staticMemberInitCalls = - allClassMembers.filter(isStaticMember).map(m => astForClassMember(m, typeDeclNode_)) - - // retrieving initialization calls from the static initialization block if any - val staticInitBlock = allClassMembers.find(isStaticInitBlock) - val staticInitBlockAsts = - staticInitBlock.map(block => - block("body").arr.toList.map(astForNodeWithFunctionReference) - ).getOrElse(List.empty) - - methodAstParentStack.pop() - dynamicInstanceTypeStack.pop() - typeRefIdStack.pop() - scope.popScope() - - if staticMemberInitCalls.nonEmpty || staticInitBlockAsts.nonEmpty then - val init = staticInitMethodAst( - staticMemberInitCalls ++ staticInitBlockAsts, - s"$typeFullName:${io.appthreat.x2cpg.Defines.StaticInitMethodName}", - None, - Defines.Any + diffGraph.addEdge( + result.methodNode, + NewModifier().modifierType(ModifierTypes.CONSTRUCTOR), + EdgeTypes.AST ) - Ast.storeInDiffGraph(init, diffGraph) - diffGraph.addEdge(typeDeclNode_, init.nodes.head, EdgeTypes.AST) - - if shouldCreateAssignmentCall then - diffGraph.addEdge(localAstParentStack.head, typeRefNode_, EdgeTypes.AST) - - // return a synthetic assignment to enable tracing of the implicitly created identifier for - // the class definition assigned to its constructor - val classIdNode = identifierNode(clazz, typeName, Seq(constructorNode.fullName)) - val constructorRefNode = - methodRefNode( - clazz, - constructorNode.code, - constructorNode.fullName, - constructorNode.fullName + result + case Some(classConstructor) => + val methodNode = + createMethodDefinitionNode( + createBabelNodeInfo(classConstructor), + methodBlockContent = constructorContent ) - - val idLocal = newLocalNode(typeName, Defines.Any).order(0) - diffGraph.addEdge(localAstParentStack.head, idLocal, EdgeTypes.AST) - scope.addVariable(typeName, idLocal, BlockScope) - scope.addVariableReference(typeName, classIdNode) - - createAssignmentCallAst( - classIdNode, - constructorRefNode, - s"$typeName = ${constructorNode.fullName}", - clazz.lineNumber, - clazz.columnNumber + diffGraph.addEdge( + methodNode, + NewModifier().modifierType(ModifierTypes.CONSTRUCTOR), + EdgeTypes.AST ) - else - Ast(typeRefNode_) - end if - end astForClass - - protected def addModifier(node: NewNode, json: Value): Unit = - createBabelNodeInfo(json).node match - case ClassPrivateProperty => - diffGraph.addEdge( - node, - NewModifier().modifierType(ModifierTypes.PRIVATE), - EdgeTypes.AST - ) - case _ => - if safeBool(json, "abstract").contains(true) then - diffGraph.addEdge( - node, - NewModifier().modifierType(ModifierTypes.ABSTRACT), - EdgeTypes.AST - ) - if safeBool(json, "static").contains(true) then - diffGraph.addEdge( - node, - NewModifier().modifierType(ModifierTypes.STATIC), - EdgeTypes.AST - ) - if safeStr(json, "accessibility").contains("public") then - diffGraph.addEdge( - node, - NewModifier().modifierType(ModifierTypes.PUBLIC), - EdgeTypes.AST - ) - if safeStr(json, "accessibility").contains("private") then - diffGraph.addEdge( - node, - NewModifier().modifierType(ModifierTypes.PRIVATE), - EdgeTypes.AST - ) - if safeStr(json, "accessibility").contains("protected") then - diffGraph.addEdge( - node, - NewModifier().modifierType(ModifierTypes.PROTECTED), - EdgeTypes.AST - ) - - protected def astForModule(tsModuleDecl: BabelNodeInfo): Ast = - val (name, fullName) = calcTypeNameAndFullName(tsModuleDecl) - val namespaceNode = NewNamespaceBlock() - .code(tsModuleDecl.code) - .lineNumber(tsModuleDecl.lineNumber) - .columnNumber(tsModuleDecl.columnNumber) - .filename(parserResult.filename) - .name(name) - .fullName(fullName) - - methodAstParentStack.push(namespaceNode) - dynamicInstanceTypeStack.push(fullName) - - scope.pushNewMethodScope(fullName, name, namespaceNode, None) - - val blockAst = if hasKey(tsModuleDecl.json, "body") then - val nodeInfo = createBabelNodeInfo(tsModuleDecl.json("body")) - nodeInfo.node match - case TSModuleDeclaration => astForModule(nodeInfo) - case _ => astForBlockStatement(nodeInfo) - else - Ast() - - methodAstParentStack.pop() - dynamicInstanceTypeStack.pop() - scope.popScope() - - Ast(namespaceNode).withChild(blockAst) - end astForModule - - protected def astForInterface(tsInterface: BabelNodeInfo): Ast = - val (typeName, typeFullName) = calcTypeNameAndFullName(tsInterface) - registerType(typeName, typeFullName) - - val astParentType = methodAstParentStack.head.label - val astParentFullName = methodAstParentStack.head.properties("FULL_NAME").toString - - val extendz = Try( - tsInterface.json("extends").arr.map(createBabelNodeInfo(_).code) - ).toOption.toSeq.flatten - - val typeDeclNode_ = typeDeclNode( - tsInterface, - typeName, - typeFullName, - parserResult.filename, - s"interface $typeName", - astParentType, - astParentFullName, - inherits = extendz - ) - seenAliasTypes.add(typeDeclNode_) - - addModifier(typeDeclNode_, tsInterface.json) - - methodAstParentStack.push(typeDeclNode_) - dynamicInstanceTypeStack.push(typeFullName) - - scope.pushNewMethodScope(typeFullName, typeName, typeDeclNode_, None) - - val constructorNode = interfaceConstructor(typeName, tsInterface) - diffGraph.addEdge( - constructorNode, - NewModifier().modifierType(ModifierTypes.CONSTRUCTOR), - EdgeTypes.AST - ) - - val constructorBindingNode = newBindingNode("", "", "") - diffGraph.addEdge(typeDeclNode_, constructorBindingNode, EdgeTypes.BINDS) - diffGraph.addEdge(constructorBindingNode, constructorNode, EdgeTypes.REF) - - val interfaceBodyElements = classMembers(tsInterface, withConstructor = false) - - interfaceBodyElements.foreach { classElement => - val nodeInfo = createBabelNodeInfo(classElement) - val typeFullName = typeFor(nodeInfo) - val memberNodes = nodeInfo.node match - case TSCallSignatureDeclaration | TSMethodSignature => - val functionNode = createMethodDefinitionNode(nodeInfo) - val bindingNode = newBindingNode("", "", "") - diffGraph.addEdge(typeDeclNode_, bindingNode, EdgeTypes.BINDS) - diffGraph.addEdge(bindingNode, functionNode, EdgeTypes.REF) - addModifier(functionNode, nodeInfo.json) - Seq(memberNode( - nodeInfo, - functionNode.name, - nodeInfo.code, - typeFullName, - Seq(functionNode.fullName) + MethodAst(Ast(methodNode), methodNode, Ast(methodNode)) + case _ => + val result = createFakeConstructor( + "constructor() {}", + classExpr, + methodBlockContent = constructorContent + ) + diffGraph.addEdge( + result.methodNode, + NewModifier().modifierType(ModifierTypes.CONSTRUCTOR), + EdgeTypes.AST + ) + result + + private def interfaceConstructor(typeName: String, tsInterface: BabelNodeInfo): NewMethod = + findClassConstructor(tsInterface) match + case Some(interfaceConstructor) => + createMethodDefinitionNode(createBabelNodeInfo(interfaceConstructor)) + case _ => createFakeConstructor(s"new: $typeName", tsInterface).methodNode + + private def astsForEnumMember(tsEnumMember: BabelNodeInfo): Seq[Ast] = + val name = code(tsEnumMember.json("id")) + val memberNode_ = memberNode(tsEnumMember, name, tsEnumMember.code, typeFor(tsEnumMember)) + addModifier(memberNode_, tsEnumMember.json) + + if hasKey(tsEnumMember.json, "initializer") then + val lhsAst = astForNode(tsEnumMember.json("id")) + val rhsAst = astForNodeWithFunctionReference(tsEnumMember.json("initializer")) + val callNode_ = + callNode( + tsEnumMember, + tsEnumMember.code, + Operators.assignment, + DispatchTypes.STATIC_DISPATCH + ) + val argAsts = List(lhsAst, rhsAst) + Seq(callAst(callNode_, argAsts), Ast(memberNode_)) + else + Seq(Ast(memberNode_)) + end astsForEnumMember + + private def astForClassMember(classElement: Value, typeDeclNode: NewTypeDecl): Ast = + val nodeInfo = createBabelNodeInfo(classElement) + val typeFullName = typeFor(nodeInfo) + val memberNode_ = nodeInfo.node match + case TSDeclareMethod | TSDeclareFunction => + val function = createMethodDefinitionNode(nodeInfo) + val bindingNode = newBindingNode("", "", "") + diffGraph.addEdge(typeDeclNode, bindingNode, EdgeTypes.BINDS) + diffGraph.addEdge(bindingNode, function, EdgeTypes.REF) + addModifier(function, nodeInfo.json) + memberNode( + nodeInfo, + function.name, + nodeInfo.code, + typeFullName, + Seq(function.fullName) + ) + case ClassMethod | ClassPrivateMethod => + val function = createMethodAstAndNode(nodeInfo).methodNode + val bindingNode = newBindingNode("", "", "") + diffGraph.addEdge(typeDeclNode, bindingNode, EdgeTypes.BINDS) + diffGraph.addEdge(bindingNode, function, EdgeTypes.REF) + addModifier(function, nodeInfo.json) + memberNode( + nodeInfo, + function.name, + nodeInfo.code, + typeFullName, + Seq(function.fullName) + ) + case ExpressionStatement if isInitializedMember(classElement) => + val memberNodeInfo = + createBabelNodeInfo(nodeInfo.json("expression")("left")("property")) + val name = memberNodeInfo.code + memberNode(nodeInfo, name, nodeInfo.code, typeFullName) + case TSPropertySignature | ObjectProperty if hasKey(nodeInfo.json("key"), "name") => + val memberNodeInfo = createBabelNodeInfo(nodeInfo.json("key")) + val name = memberNodeInfo.json("name").str + memberNode(nodeInfo, name, nodeInfo.code, typeFullName) + case _ => + val name = nodeInfo.node match + case ClassProperty => code(nodeInfo.json("key")) + case ClassPrivateProperty => code(nodeInfo.json("key")("id")) + // TODO: name field most likely needs adjustment for other Babel AST types + case _ => nodeInfo.code + memberNode(nodeInfo, name, nodeInfo.code, typeFullName) + + addModifier(memberNode_, classElement) + diffGraph.addEdge(typeDeclNode, memberNode_, EdgeTypes.AST) + astsForDecorators(nodeInfo).foreach { decoratorAst => + Ast.storeInDiffGraph(decoratorAst, diffGraph) + decoratorAst.root.foreach(diffGraph.addEdge(memberNode_, _, EdgeTypes.AST)) + } + + if hasKey(nodeInfo.json, "value") && !nodeInfo.json("value").isNull then + val lhsAst = astForNode(nodeInfo.json("key")) + val rhsAst = astForNodeWithFunctionReference(nodeInfo.json("value")) + val callNode_ = + callNode( + nodeInfo, + nodeInfo.code, + Operators.assignment, + DispatchTypes.STATIC_DISPATCH + ) + val argAsts = List(lhsAst, rhsAst) + callAst(callNode_, argAsts) + else + Ast() + end astForClassMember + + protected def astForEnum(tsEnum: BabelNodeInfo): Ast = + val (typeName, typeFullName) = calcTypeNameAndFullName(tsEnum) + registerType(typeName, typeFullName) + + val astParentType = methodAstParentStack.head.label + val astParentFullName = methodAstParentStack.head.properties("FULL_NAME").toString + + val typeDeclNode_ = typeDeclNode( + tsEnum, + typeName, + typeFullName, + parserResult.filename, + s"enum $typeName", + astParentType, + astParentFullName + ) + seenAliasTypes.add(typeDeclNode_) + + addModifier(typeDeclNode_, tsEnum.json) + + val typeRefNode_ = typeRefNode(tsEnum, s"enum $typeName", typeFullName) + + methodAstParentStack.push(typeDeclNode_) + dynamicInstanceTypeStack.push(typeFullName) + typeRefIdStack.push(typeRefNode_) + scope.pushNewMethodScope(typeFullName, typeName, typeDeclNode_, None) + + val memberAsts = tsEnum.json("members").arr.toList.flatMap(m => + astsForEnumMember(createBabelNodeInfo(m)) + ) + + methodAstParentStack.pop() + dynamicInstanceTypeStack.pop() + typeRefIdStack.pop() + scope.popScope() + + val (calls, member) = + memberAsts.partition(_.nodes.headOption.exists(_.isInstanceOf[NewCall])) + if calls.isEmpty then + Ast.storeInDiffGraph(Ast(typeDeclNode_).withChildren(member), diffGraph) + else + val init = + staticInitMethodAst( + calls, + s"$typeFullName:${io.appthreat.x2cpg.Defines.StaticInitMethodName}", + None, + Defines.Any + ) + Ast.storeInDiffGraph(Ast(typeDeclNode_).withChildren(member).withChild(init), diffGraph) + + diffGraph.addEdge(methodAstParentStack.head, typeDeclNode_, EdgeTypes.AST) + Ast(typeRefNode_) + end astForEnum + + private def isStaticMember(json: Value): Boolean = + val nodeInfo = createBabelNodeInfo(json).node + val isStatic = safeBool(json, "static").contains(true) + nodeInfo != ClassMethod && nodeInfo != ClassPrivateMethod && isStatic + + private def isInitializedMember(json: Value): Boolean = + val hasInitializedValue = hasKey(json, "value") && !json("value").isNull + val isAssignment = createBabelNodeInfo(json) match + case node if node.node == ExpressionStatement => + val exprNode = createBabelNodeInfo(node.json("expression")) + exprNode.node == AssignmentExpression && + createBabelNodeInfo(exprNode.json("left")).node == MemberExpression && + code(exprNode.json("left")("object")) == "this" + case _ => false + hasInitializedValue || isAssignment + + private def isStaticInitBlock(json: Value): Boolean = + createBabelNodeInfo(json).node == StaticBlock + + private def isClassMethodOrUninitializedMember(json: Value): Boolean = + val nodeInfo = createBabelNodeInfo(json).node + !isStaticInitBlock(json) && + (nodeInfo == ClassMethod || nodeInfo == ClassPrivateMethod || !isInitializedMember(json)) + + private def isClassMethodOrUninitializedMemberOrObjectProperty(json: Value): Boolean = + val nodeInfo = createBabelNodeInfo(json).node + !isStaticInitBlock(json) && + (nodeInfo == ObjectProperty || nodeInfo == ClassMethod || nodeInfo == ClassPrivateMethod || !isInitializedMember( + json + )) + + protected def astForClass( + clazz: BabelNodeInfo, + shouldCreateAssignmentCall: Boolean = false + ): Ast = + val (typeName, typeFullName) = calcTypeNameAndFullName(clazz) + registerType(typeName, typeFullName) + + val astParentType = methodAstParentStack.head.label + val astParentFullName = methodAstParentStack.head.properties("FULL_NAME").toString + + val superClass = Try(createBabelNodeInfo(clazz.json("superClass")).code).toOption.toSeq + val implements = Try( + clazz.json("implements").arr.map(createBabelNodeInfo(_).code) + ).toOption.toSeq.flatten + val mixins = + Try(clazz.json("mixins").arr.map(createBabelNodeInfo(_).code)).toOption.toSeq.flatten + + val typeDeclNode_ = typeDeclNode( + clazz, + typeName, + typeFullName, + parserResult.filename, + s"class $typeName", + astParentType, + astParentFullName, + inherits = superClass ++ implements ++ mixins + ) + seenAliasTypes.add(typeDeclNode_) + + addModifier(typeDeclNode_, clazz.json) + astsForDecorators(clazz).foreach { decoratorAst => + Ast.storeInDiffGraph(decoratorAst, diffGraph) + decoratorAst.root.foreach(diffGraph.addEdge(typeDeclNode_, _, EdgeTypes.AST)) + } + + diffGraph.addEdge(methodAstParentStack.head, typeDeclNode_, EdgeTypes.AST) + + val typeRefNode_ = typeRefNode(clazz, s"class $typeName", typeFullName) + + methodAstParentStack.push(typeDeclNode_) + dynamicInstanceTypeStack.push(typeFullName) + typeRefIdStack.push(typeRefNode_) + + scope.pushNewMethodScope(typeFullName, typeName, typeDeclNode_, None) + + val allClassMembers = classMembers(clazz, withConstructor = false).toList + + // adding all other members and retrieving their initialization calls + val memberInitCalls = allClassMembers + .filter(m => !isStaticMember(m) && isInitializedMember(m)) + .map(m => astForClassMember(m, typeDeclNode_)) + + val constructor = createClassConstructor(clazz, memberInitCalls) + val constructorNode = constructor.methodNode + + // adding all class methods / functions and uninitialized, non-static members + allClassMembers + .filter(member => isClassMethodOrUninitializedMember(member) && !isStaticMember(member)) + .map(m => astForClassMember(m, typeDeclNode_)) + + // adding all static members and retrieving their initialization calls + val staticMemberInitCalls = + allClassMembers.filter(isStaticMember).map(m => astForClassMember(m, typeDeclNode_)) + + // retrieving initialization calls from the static initialization block if any + val staticInitBlock = allClassMembers.find(isStaticInitBlock) + val staticInitBlockAsts = + staticInitBlock.map(block => + block("body").arr.toList.map(astForNodeWithFunctionReference) + ).getOrElse(List.empty) + + methodAstParentStack.pop() + dynamicInstanceTypeStack.pop() + typeRefIdStack.pop() + scope.popScope() + + if staticMemberInitCalls.nonEmpty || staticInitBlockAsts.nonEmpty then + val init = staticInitMethodAst( + staticMemberInitCalls ++ staticInitBlockAsts, + s"$typeFullName:${io.appthreat.x2cpg.Defines.StaticInitMethodName}", + None, + Defines.Any + ) + Ast.storeInDiffGraph(init, diffGraph) + diffGraph.addEdge(typeDeclNode_, init.nodes.head, EdgeTypes.AST) + + if shouldCreateAssignmentCall then + diffGraph.addEdge(localAstParentStack.head, typeRefNode_, EdgeTypes.AST) + + // return a synthetic assignment to enable tracing of the implicitly created identifier for + // the class definition assigned to its constructor + val classIdNode = identifierNode(clazz, typeName, Seq(constructorNode.fullName)) + val constructorRefNode = + methodRefNode( + clazz, + constructorNode.code, + constructorNode.fullName, + constructorNode.fullName + ) + + val idLocal = newLocalNode(typeName, Defines.Any).order(0) + diffGraph.addEdge(localAstParentStack.head, idLocal, EdgeTypes.AST) + scope.addVariable(typeName, idLocal, BlockScope) + scope.addVariableReference(typeName, classIdNode) + + createAssignmentCallAst( + classIdNode, + constructorRefNode, + s"$typeName = ${constructorNode.fullName}", + clazz.lineNumber, + clazz.columnNumber + ) + else + Ast(typeRefNode_) + end if + end astForClass + + protected def addModifier(node: NewNode, json: Value): Unit = + createBabelNodeInfo(json).node match + case ClassPrivateProperty => + diffGraph.addEdge( + node, + NewModifier().modifierType(ModifierTypes.PRIVATE), + EdgeTypes.AST + ) + case _ => + if safeBool(json, "abstract").contains(true) then + diffGraph.addEdge( + node, + NewModifier().modifierType(ModifierTypes.ABSTRACT), + EdgeTypes.AST + ) + if safeBool(json, "static").contains(true) then + diffGraph.addEdge( + node, + NewModifier().modifierType(ModifierTypes.STATIC), + EdgeTypes.AST + ) + if safeStr(json, "accessibility").contains("public") then + diffGraph.addEdge( + node, + NewModifier().modifierType(ModifierTypes.PUBLIC), + EdgeTypes.AST + ) + if safeStr(json, "accessibility").contains("private") then + diffGraph.addEdge( + node, + NewModifier().modifierType(ModifierTypes.PRIVATE), + EdgeTypes.AST + ) + if safeStr(json, "accessibility").contains("protected") then + diffGraph.addEdge( + node, + NewModifier().modifierType(ModifierTypes.PROTECTED), + EdgeTypes.AST + ) + + protected def astForModule(tsModuleDecl: BabelNodeInfo): Ast = + val (name, fullName) = calcTypeNameAndFullName(tsModuleDecl) + val namespaceNode = NewNamespaceBlock() + .code(tsModuleDecl.code) + .lineNumber(tsModuleDecl.lineNumber) + .columnNumber(tsModuleDecl.columnNumber) + .filename(parserResult.filename) + .name(name) + .fullName(fullName) + + methodAstParentStack.push(namespaceNode) + dynamicInstanceTypeStack.push(fullName) + + scope.pushNewMethodScope(fullName, name, namespaceNode, None) + + val blockAst = if hasKey(tsModuleDecl.json, "body") then + val nodeInfo = createBabelNodeInfo(tsModuleDecl.json("body")) + nodeInfo.node match + case TSModuleDeclaration => astForModule(nodeInfo) + case _ => astForBlockStatement(nodeInfo) + else + Ast() + + methodAstParentStack.pop() + dynamicInstanceTypeStack.pop() + scope.popScope() + + Ast(namespaceNode).withChild(blockAst) + end astForModule + + protected def astForInterface(tsInterface: BabelNodeInfo): Ast = + val (typeName, typeFullName) = calcTypeNameAndFullName(tsInterface) + registerType(typeName, typeFullName) + + val astParentType = methodAstParentStack.head.label + val astParentFullName = methodAstParentStack.head.properties("FULL_NAME").toString + + val extendz = Try( + tsInterface.json("extends").arr.map(createBabelNodeInfo(_).code) + ).toOption.toSeq.flatten + + val typeDeclNode_ = typeDeclNode( + tsInterface, + typeName, + typeFullName, + parserResult.filename, + s"interface $typeName", + astParentType, + astParentFullName, + inherits = extendz + ) + seenAliasTypes.add(typeDeclNode_) + + addModifier(typeDeclNode_, tsInterface.json) + + methodAstParentStack.push(typeDeclNode_) + dynamicInstanceTypeStack.push(typeFullName) + + scope.pushNewMethodScope(typeFullName, typeName, typeDeclNode_, None) + + val constructorNode = interfaceConstructor(typeName, tsInterface) + diffGraph.addEdge( + constructorNode, + NewModifier().modifierType(ModifierTypes.CONSTRUCTOR), + EdgeTypes.AST + ) + + val constructorBindingNode = newBindingNode("", "", "") + diffGraph.addEdge(typeDeclNode_, constructorBindingNode, EdgeTypes.BINDS) + diffGraph.addEdge(constructorBindingNode, constructorNode, EdgeTypes.REF) + + val interfaceBodyElements = classMembers(tsInterface, withConstructor = false) + + interfaceBodyElements.foreach { classElement => + val nodeInfo = createBabelNodeInfo(classElement) + val typeFullName = typeFor(nodeInfo) + val memberNodes = nodeInfo.node match + case TSCallSignatureDeclaration | TSMethodSignature => + val functionNode = createMethodDefinitionNode(nodeInfo) + val bindingNode = newBindingNode("", "", "") + diffGraph.addEdge(typeDeclNode_, bindingNode, EdgeTypes.BINDS) + diffGraph.addEdge(bindingNode, functionNode, EdgeTypes.REF) + addModifier(functionNode, nodeInfo.json) + Seq(memberNode( + nodeInfo, + functionNode.name, + nodeInfo.code, + typeFullName, + Seq(functionNode.fullName) + )) + case _ => + val names = nodeInfo.node match + case TSPropertySignature | TSMethodSignature => + if hasKey(nodeInfo.json("key"), "value") then + Seq(safeStr(nodeInfo.json("key"), "value").getOrElse( + code(nodeInfo.json("key")("value")) )) - case _ => - val names = nodeInfo.node match - case TSPropertySignature | TSMethodSignature => - if hasKey(nodeInfo.json("key"), "value") then - Seq(safeStr(nodeInfo.json("key"), "value").getOrElse( - code(nodeInfo.json("key")("value")) - )) - else Seq(code(nodeInfo.json("key"))) - case TSIndexSignature => - nodeInfo.json("parameters").arr.toSeq.map(_("name").str) - // TODO: name field most likely needs adjustment for other Babel AST types - case _ => Seq(nodeInfo.code) - names.map { n => - val node = memberNode(nodeInfo, n, nodeInfo.code, typeFullName) - addModifier(node, nodeInfo.json) - node - } - memberNodes.foreach(diffGraph.addEdge(typeDeclNode_, _, EdgeTypes.AST)) - } - - methodAstParentStack.pop() - dynamicInstanceTypeStack.pop() - scope.popScope() - - Ast(typeDeclNode_) - end astForInterface + else Seq(code(nodeInfo.json("key"))) + case TSIndexSignature => + nodeInfo.json("parameters").arr.toSeq.map(_("name").str) + // TODO: name field most likely needs adjustment for other Babel AST types + case _ => Seq(nodeInfo.code) + names.map { n => + val node = memberNode(nodeInfo, n, nodeInfo.code, typeFullName) + addModifier(node, nodeInfo.json) + node + } + memberNodes.foreach(diffGraph.addEdge(typeDeclNode_, _, EdgeTypes.AST)) + } + + methodAstParentStack.pop() + dynamicInstanceTypeStack.pop() + scope.popScope() + + Ast(typeDeclNode_) + end astForInterface end AstForTypesCreator diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstNodeBuilder.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstNodeBuilder.scala index 050584a3..c932336a 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstNodeBuilder.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/AstNodeBuilder.scala @@ -10,289 +10,289 @@ import io.shiftleft.codepropertygraph.generated.DispatchTypes import io.shiftleft.codepropertygraph.generated.Operators trait AstNodeBuilder(implicit withSchemaValidation: ValidationMode): - this: AstCreator => - protected def createMethodReturnNode(func: BabelNodeInfo): NewMethodReturn = - newMethodReturnNode(typeFor(func), line = func.lineNumber, column = func.columnNumber) + this: AstCreator => + protected def createMethodReturnNode(func: BabelNodeInfo): NewMethodReturn = + newMethodReturnNode(typeFor(func), line = func.lineNumber, column = func.columnNumber) - protected def setOrderExplicitly(ast: Ast, order: Int): Unit = - ast.root.foreach { case expr: ExpressionNew => expr.order = order } + protected def setOrderExplicitly(ast: Ast, order: Int): Unit = + ast.root.foreach { case expr: ExpressionNew => expr.order = order } - protected def createJumpTarget(switchCase: BabelNodeInfo): NewJumpTarget = - val (switchName, switchCode) = if switchCase.json("test").isNull then - ("default", "default:") - else - ("case", s"case ${code(switchCase.json("test"))}:") - NewJumpTarget() - .parserTypeName(switchCase.node.toString) - .name(switchName) - .code(switchCode) - .lineNumber(switchCase.lineNumber) - .columnNumber(switchCase.columnNumber) + protected def createJumpTarget(switchCase: BabelNodeInfo): NewJumpTarget = + val (switchName, switchCode) = if switchCase.json("test").isNull then + ("default", "default:") + else + ("case", s"case ${code(switchCase.json("test"))}:") + NewJumpTarget() + .parserTypeName(switchCase.node.toString) + .name(switchName) + .code(switchCode) + .lineNumber(switchCase.lineNumber) + .columnNumber(switchCase.columnNumber) - protected def createControlStructureNode( - node: BabelNodeInfo, - controlStructureType: String - ): NewControlStructure = - val line = node.lineNumber - val column = node.columnNumber - val code = node.code - NewControlStructure() - .controlStructureType(controlStructureType) - .code(code) - .lineNumber(line) - .columnNumber(column) + protected def createControlStructureNode( + node: BabelNodeInfo, + controlStructureType: String + ): NewControlStructure = + val line = node.lineNumber + val column = node.columnNumber + val code = node.code + NewControlStructure() + .controlStructureType(controlStructureType) + .code(code) + .lineNumber(line) + .columnNumber(column) - protected def codeOf(node: NewNode): String = node match - case astNodeNew: AstNodeNew => astNodeNew.code - case _ => "" + protected def codeOf(node: NewNode): String = node match + case astNodeNew: AstNodeNew => astNodeNew.code + case _ => "" - protected def createIndexAccessCallAst( - baseNode: NewNode, - partNode: NewNode, - line: Option[Integer], - column: Option[Integer] - ): Ast = - val callNode = createCallNode( - s"${codeOf(baseNode)}[${codeOf(partNode)}]", - Operators.indexAccess, - DispatchTypes.STATIC_DISPATCH, - line, - column - ) - val arguments = List(Ast(baseNode), Ast(partNode)) - callAst(callNode, arguments) + protected def createIndexAccessCallAst( + baseNode: NewNode, + partNode: NewNode, + line: Option[Integer], + column: Option[Integer] + ): Ast = + val callNode = createCallNode( + s"${codeOf(baseNode)}[${codeOf(partNode)}]", + Operators.indexAccess, + DispatchTypes.STATIC_DISPATCH, + line, + column + ) + val arguments = List(Ast(baseNode), Ast(partNode)) + callAst(callNode, arguments) - protected def createIndexAccessCallAst( - baseAst: Ast, - partAst: Ast, - line: Option[Integer], - column: Option[Integer] - ): Ast = - val callNode = createCallNode( - s"${codeOf(baseAst.nodes.head)}[${codeOf(partAst.nodes.head)}]", - Operators.indexAccess, - DispatchTypes.STATIC_DISPATCH, - line, - column - ) - val arguments = List(baseAst, partAst) - callAst(callNode, arguments) + protected def createIndexAccessCallAst( + baseAst: Ast, + partAst: Ast, + line: Option[Integer], + column: Option[Integer] + ): Ast = + val callNode = createCallNode( + s"${codeOf(baseAst.nodes.head)}[${codeOf(partAst.nodes.head)}]", + Operators.indexAccess, + DispatchTypes.STATIC_DISPATCH, + line, + column + ) + val arguments = List(baseAst, partAst) + callAst(callNode, arguments) - protected def createFieldAccessCallAst( - baseNode: NewNode, - partNode: NewNode, - line: Option[Integer], - column: Option[Integer] - ): Ast = - val callNode = createCallNode( - s"${codeOf(baseNode)}.${codeOf(partNode)}", - Operators.fieldAccess, - DispatchTypes.STATIC_DISPATCH, - line, - column - ) - val arguments = List(Ast(baseNode), Ast(partNode)) - callAst(callNode, arguments) - - protected def createFieldAccessCallAst( - baseAst: Ast, - partNode: NewNode, - line: Option[Integer], - column: Option[Integer] - ): Ast = - val callNode = createCallNode( - s"${codeOf(baseAst.nodes.head)}.${codeOf(partNode)}", - Operators.fieldAccess, - DispatchTypes.STATIC_DISPATCH, - line, - column - ) - val arguments = List(baseAst, Ast(partNode)) - callAst(callNode, arguments) + protected def createFieldAccessCallAst( + baseNode: NewNode, + partNode: NewNode, + line: Option[Integer], + column: Option[Integer] + ): Ast = + val callNode = createCallNode( + s"${codeOf(baseNode)}.${codeOf(partNode)}", + Operators.fieldAccess, + DispatchTypes.STATIC_DISPATCH, + line, + column + ) + val arguments = List(Ast(baseNode), Ast(partNode)) + callAst(callNode, arguments) - protected def createTernaryCallAst( - testAst: Ast, - trueAst: Ast, - falseAst: Ast, - line: Option[Integer], - column: Option[Integer] - ): Ast = - val code = - s"${codeOf(testAst.nodes.head)} ? ${codeOf(trueAst.nodes.head)} : ${codeOf(falseAst.nodes.head)}" - val callNode = - createCallNode(code, Operators.conditional, DispatchTypes.STATIC_DISPATCH, line, column) - val arguments = List(testAst, trueAst, falseAst) - callAst(callNode, arguments) + protected def createFieldAccessCallAst( + baseAst: Ast, + partNode: NewNode, + line: Option[Integer], + column: Option[Integer] + ): Ast = + val callNode = createCallNode( + s"${codeOf(baseAst.nodes.head)}.${codeOf(partNode)}", + Operators.fieldAccess, + DispatchTypes.STATIC_DISPATCH, + line, + column + ) + val arguments = List(baseAst, Ast(partNode)) + callAst(callNode, arguments) - def callNode(node: BabelNodeInfo, code: String, name: String, dispatchType: String): NewCall = - val fullName = - if dispatchType == DispatchTypes.STATIC_DISPATCH then name - else x2cpg.Defines.DynamicCallUnknownFullName - callNode(node, code, name, fullName, dispatchType, None, Some(Defines.Any)) + protected def createTernaryCallAst( + testAst: Ast, + trueAst: Ast, + falseAst: Ast, + line: Option[Integer], + column: Option[Integer] + ): Ast = + val code = + s"${codeOf(testAst.nodes.head)} ? ${codeOf(trueAst.nodes.head)} : ${codeOf(falseAst.nodes.head)}" + val callNode = + createCallNode(code, Operators.conditional, DispatchTypes.STATIC_DISPATCH, line, column) + val arguments = List(testAst, trueAst, falseAst) + callAst(callNode, arguments) - private def createCallNode( - code: String, - callName: String, - dispatchType: String, - line: Option[Integer], - column: Option[Integer] - ): NewCall = NewCall() - .code(code) - .name(callName) - .methodFullName( - if dispatchType == DispatchTypes.STATIC_DISPATCH then callName - else x2cpg.Defines.DynamicCallUnknownFullName - ) - .dispatchType(dispatchType) - .lineNumber(line) - .columnNumber(column) - .typeFullName(Defines.Any) + def callNode(node: BabelNodeInfo, code: String, name: String, dispatchType: String): NewCall = + val fullName = + if dispatchType == DispatchTypes.STATIC_DISPATCH then name + else x2cpg.Defines.DynamicCallUnknownFullName + callNode(node, code, name, fullName, dispatchType, None, Some(Defines.Any)) - protected def createVoidCallNode(line: Option[Integer], column: Option[Integer]): NewCall = - createCallNode("void 0", ".void", DispatchTypes.STATIC_DISPATCH, line, column) + private def createCallNode( + code: String, + callName: String, + dispatchType: String, + line: Option[Integer], + column: Option[Integer] + ): NewCall = NewCall() + .code(code) + .name(callName) + .methodFullName( + if dispatchType == DispatchTypes.STATIC_DISPATCH then callName + else x2cpg.Defines.DynamicCallUnknownFullName + ) + .dispatchType(dispatchType) + .lineNumber(line) + .columnNumber(column) + .typeFullName(Defines.Any) - protected def createFieldIdentifierNode( - name: String, - line: Option[Integer], - column: Option[Integer] - ): NewFieldIdentifier = - val cleanedName = stripQuotes(name) - NewFieldIdentifier() - .code(cleanedName) - .canonicalName(cleanedName) - .lineNumber(line) - .columnNumber(column) + protected def createVoidCallNode(line: Option[Integer], column: Option[Integer]): NewCall = + createCallNode("void 0", ".void", DispatchTypes.STATIC_DISPATCH, line, column) - protected def literalNode( - node: BabelNodeInfo, - code: String, - dynamicTypeOption: Option[String] - ): NewLiteral = - val typeFullName = dynamicTypeOption match - case Some(value) if Defines.JsTypes.contains(value) => value - case _ => Defines.Any - literalNode(node, code, typeFullName, dynamicTypeOption.toList) + protected def createFieldIdentifierNode( + name: String, + line: Option[Integer], + column: Option[Integer] + ): NewFieldIdentifier = + val cleanedName = stripQuotes(name) + NewFieldIdentifier() + .code(cleanedName) + .canonicalName(cleanedName) + .lineNumber(line) + .columnNumber(column) - protected def createEqualsCallAst( - dest: Ast, - source: Ast, - line: Option[Integer], - column: Option[Integer] - ): Ast = - val code = s"${codeOf(dest.nodes.head)} === ${codeOf(source.nodes.head)}" - val callNode = - createCallNode(code, Operators.equals, DispatchTypes.STATIC_DISPATCH, line, column) - val arguments = List(dest, source) - callAst(callNode, arguments) + protected def literalNode( + node: BabelNodeInfo, + code: String, + dynamicTypeOption: Option[String] + ): NewLiteral = + val typeFullName = dynamicTypeOption match + case Some(value) if Defines.JsTypes.contains(value) => value + case _ => Defines.Any + literalNode(node, code, typeFullName, dynamicTypeOption.toList) - protected def createAssignmentCallAst( - destId: NewNode, - sourceId: NewNode, - code: String, - line: Option[Integer], - column: Option[Integer] - ): Ast = - val callNode = - createCallNode(code, Operators.assignment, DispatchTypes.STATIC_DISPATCH, line, column) - val arguments = List(Ast(destId), Ast(sourceId)) - callAst(callNode, arguments) + protected def createEqualsCallAst( + dest: Ast, + source: Ast, + line: Option[Integer], + column: Option[Integer] + ): Ast = + val code = s"${codeOf(dest.nodes.head)} === ${codeOf(source.nodes.head)}" + val callNode = + createCallNode(code, Operators.equals, DispatchTypes.STATIC_DISPATCH, line, column) + val arguments = List(dest, source) + callAst(callNode, arguments) - protected def createAssignmentCallAst( - dest: Ast, - source: Ast, - code: String, - line: Option[Integer], - column: Option[Integer] - ): Ast = - val callNode = - createCallNode(code, Operators.assignment, DispatchTypes.STATIC_DISPATCH, line, column) - val arguments = List(dest, source) - callAst(callNode, arguments) + protected def createAssignmentCallAst( + destId: NewNode, + sourceId: NewNode, + code: String, + line: Option[Integer], + column: Option[Integer] + ): Ast = + val callNode = + createCallNode(code, Operators.assignment, DispatchTypes.STATIC_DISPATCH, line, column) + val arguments = List(Ast(destId), Ast(sourceId)) + callAst(callNode, arguments) - protected def identifierNode(node: BabelNodeInfo, name: String): NewIdentifier = - val dynamicInstanceTypeOption = name match - case "this" => typeHintForThisExpression(Option(node)).headOption - case "console" => Option(Defines.Console) - case "Math" => Option(Defines.Math) - case _ => None - identifierNode(node, name, name, Defines.Any, dynamicInstanceTypeOption.toList) + protected def createAssignmentCallAst( + dest: Ast, + source: Ast, + code: String, + line: Option[Integer], + column: Option[Integer] + ): Ast = + val callNode = + createCallNode(code, Operators.assignment, DispatchTypes.STATIC_DISPATCH, line, column) + val arguments = List(dest, source) + callAst(callNode, arguments) - protected def identifierNode( - node: BabelNodeInfo, - name: String, - dynamicTypeHints: Seq[String] - ): NewIdentifier = - identifierNode(node, name, name, Defines.Any, dynamicTypeHints) + protected def identifierNode(node: BabelNodeInfo, name: String): NewIdentifier = + val dynamicInstanceTypeOption = name match + case "this" => typeHintForThisExpression(Option(node)).headOption + case "console" => Option(Defines.Console) + case "Math" => Option(Defines.Math) + case _ => None + identifierNode(node, name, name, Defines.Any, dynamicInstanceTypeOption.toList) - protected def createStaticCallNode( - code: String, - callName: String, - fullName: String, - line: Option[Integer], - column: Option[Integer] - ): NewCall = NewCall() - .code(code) - .name(callName) - .methodFullName(fullName) - .dispatchType(DispatchTypes.STATIC_DISPATCH) - .signature("") - .lineNumber(line) - .columnNumber(column) - .typeFullName(Defines.Any) + protected def identifierNode( + node: BabelNodeInfo, + name: String, + dynamicTypeHints: Seq[String] + ): NewIdentifier = + identifierNode(node, name, name, Defines.Any, dynamicTypeHints) - protected def createTemplateDomNode( - name: String, - code: String, - line: Option[Integer], - column: Option[Integer] - ): NewTemplateDom = - NewTemplateDom() - .name(name) - .code(code) - .lineNumber(line) - .columnNumber(column) + protected def createStaticCallNode( + code: String, + callName: String, + fullName: String, + line: Option[Integer], + column: Option[Integer] + ): NewCall = NewCall() + .code(code) + .name(callName) + .methodFullName(fullName) + .dispatchType(DispatchTypes.STATIC_DISPATCH) + .signature("") + .lineNumber(line) + .columnNumber(column) + .typeFullName(Defines.Any) - protected def createBlockNode( - node: BabelNodeInfo, - customCode: Option[String] = None - ): NewBlock = - NewBlock() - .typeFullName(Defines.Any) - .code(customCode.getOrElse(node.code)) - .lineNumber(node.lineNumber) - .columnNumber(node.columnNumber) + protected def createTemplateDomNode( + name: String, + code: String, + line: Option[Integer], + column: Option[Integer] + ): NewTemplateDom = + NewTemplateDom() + .name(name) + .code(code) + .lineNumber(line) + .columnNumber(column) - protected def createFunctionTypeAndTypeDeclAst( - node: BabelNodeInfo, - methodNode: NewMethod, - parentNode: NewNode, - methodName: String, - methodFullName: String, - filename: String - ): Ast = - registerType(methodName, methodFullName) + protected def createBlockNode( + node: BabelNodeInfo, + customCode: Option[String] = None + ): NewBlock = + NewBlock() + .typeFullName(Defines.Any) + .code(customCode.getOrElse(node.code)) + .lineNumber(node.lineNumber) + .columnNumber(node.columnNumber) - val astParentType = parentNode.label - val astParentFullName = parentNode.properties("FULL_NAME").toString - val functionTypeDeclNode = - typeDeclNode( - node, - methodName, - methodFullName, - filename, - methodName, - astParentType = astParentType, - astParentFullName = astParentFullName, - List(Defines.Any) - ) + protected def createFunctionTypeAndTypeDeclAst( + node: BabelNodeInfo, + methodNode: NewMethod, + parentNode: NewNode, + methodName: String, + methodFullName: String, + filename: String + ): Ast = + registerType(methodName, methodFullName) - // Problem for https://github.com/ShiftLeftSecurity/codescience/issues/3626 here. - // As the type (thus, the signature) of the function node is unknown (i.e., ANY*) - // we can't generate the correct binding with signature. - val bindingNode = NewBinding().name("").signature("") - Ast(functionTypeDeclNode).withBindsEdge(functionTypeDeclNode, bindingNode).withRefEdge( - bindingNode, - methodNode + val astParentType = parentNode.label + val astParentFullName = parentNode.properties("FULL_NAME").toString + val functionTypeDeclNode = + typeDeclNode( + node, + methodName, + methodFullName, + filename, + methodName, + astParentType = astParentType, + astParentFullName = astParentFullName, + List(Defines.Any) ) - end createFunctionTypeAndTypeDeclAst + + // Problem for https://github.com/ShiftLeftSecurity/codescience/issues/3626 here. + // As the type (thus, the signature) of the function node is unknown (i.e., ANY*) + // we can't generate the correct binding with signature. + val bindingNode = NewBinding().name("").signature("") + Ast(functionTypeDeclNode).withBindsEdge(functionTypeDeclNode, bindingNode).withRefEdge( + bindingNode, + methodNode + ) + end createFunctionTypeAndTypeDeclAst end AstNodeBuilder diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/TypeHelper.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/TypeHelper.scala index 7f84e1f3..982c04d0 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/TypeHelper.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/astcreation/TypeHelper.scala @@ -7,137 +7,137 @@ import io.appthreat.jssrc2cpg.parser.BabelAst.* import java.util.regex.Pattern trait TypeHelper: - this: AstCreator => - - private val TypeAnnotationKey = "typeAnnotation" - private val ReturnTypeKey = "returnType" - private val ImportMatcher = Pattern.compile("(typeof )?import\\([\"'](.*)[\"']\\)") - - private val ArrayReplacements = Map( - "any[]" -> s"${Defines.Any}[]", - "unknown[]" -> s"${Defines.Unknown}[]", - "number[]" -> s"${Defines.Number}[]", - "string[]" -> s"${Defines.String}[]", - "boolean[]" -> s"${Defines.Boolean}[]" - ) - - private val TypeReplacements = Map( - " any" -> s" ${Defines.Any}", - " unknown" -> s" ${Defines.Unknown}", - " number" -> s" ${Defines.Number}", - " null" -> s" ${Defines.Null}", - " string" -> s" ${Defines.String}", - " boolean" -> s" ${Defines.Boolean}", - " bigint" -> s" ${Defines.BigInt}", - "{}" -> Defines.Object, - "typeof " -> "" - ) - - protected def isPlainTypeAlias(alias: BabelNodeInfo): Boolean = - if hasKey(alias.json, "right") then - createBabelNodeInfo(alias.json("right")).node.toString == TSTypeReference.toString - else - createBabelNodeInfo( - alias.json("typeAnnotation") - ).node.toString == TSTypeReference.toString - - private def typeForFlowType(flowType: BabelNodeInfo): String = flowType.node match - case BooleanTypeAnnotation => Defines.Boolean - case NumberTypeAnnotation => Defines.Number - case ObjectTypeAnnotation => Defines.Object - case StringTypeAnnotation => Defines.String - case SymbolTypeAnnotation => Defines.Symbol - case NumberLiteralTypeAnnotation => code(flowType.json) - case ArrayTypeAnnotation => code(flowType.json) - case BooleanLiteralTypeAnnotation => code(flowType.json) - case NullLiteralTypeAnnotation => code(flowType.json) - case StringLiteralTypeAnnotation => code(flowType.json) - case GenericTypeAnnotation => code(flowType.json("id")) - case ThisTypeAnnotation => - typeHintForThisExpression(Option(flowType)).headOption.getOrElse(Defines.Any); - case NullableTypeAnnotation => - typeForTypeAnnotation(createBabelNodeInfo(flowType.json(TypeAnnotationKey))) - case _ => Defines.Any - - private def typeForTsType(tsType: BabelNodeInfo): String = tsType.node match - case TSBooleanKeyword => Defines.Boolean - case TSBigIntKeyword => Defines.Number - case TSNullKeyword => Defines.Null - case TSNumberKeyword => Defines.Number - case TSObjectKeyword => Defines.Object - case TSStringKeyword => Defines.String - case TSSymbolKeyword => Defines.Symbol - case TSUnknownKeyword => Defines.Unknown - case TSVoidKeyword => Defines.Void - case TSUndefinedKeyword => Defines.Undefined - case TSNeverKeyword => Defines.Never - case TSIntrinsicKeyword => code(tsType.json) - case TSTypeReference => code(tsType.json) - case TSArrayType => code(tsType.json) - case TSThisType => - typeHintForThisExpression(Option(tsType)).headOption.getOrElse(Defines.Any) - case TSOptionalType => - typeForTypeAnnotation(createBabelNodeInfo(tsType.json(TypeAnnotationKey))) - case TSRestType => - typeForTypeAnnotation(createBabelNodeInfo(tsType.json(TypeAnnotationKey))) - case TSParenthesizedType => - typeForTypeAnnotation(createBabelNodeInfo(tsType.json(TypeAnnotationKey))) - case _ => Defines.Any - - private def typeForTypeAnnotation(typeAnnotation: BabelNodeInfo): String = - typeAnnotation.node match - case TypeAnnotation => - typeForFlowType(createBabelNodeInfo(typeAnnotation.json(TypeAnnotationKey))) - case TSTypeAnnotation => - typeForTsType(createBabelNodeInfo(typeAnnotation.json(TypeAnnotationKey))) - case _: FlowType => typeForFlowType(createBabelNodeInfo(typeAnnotation.json)) - case _: TSType => typeForTsType(createBabelNodeInfo(typeAnnotation.json)) - case _ => Defines.Any - - private def isStringType(tpe: String): Boolean = - tpe.startsWith("\"") && tpe.endsWith("\"") - - private def isNumberType(tpe: String): Boolean = - tpe.toDoubleOption.isDefined - - private def typeFromTypeMap(node: BabelNodeInfo): String = - pos(node.json).flatMap(parserResult.typeMap.get) match - case Some(value) if value.isEmpty => Defines.String - case Some(value) if value == "string" => Defines.String - case Some(value) if isStringType(value) => Defines.String - case Some(value) if value == "number" => Defines.Number - case Some(value) if isNumberType(value) => Defines.Number - case Some(value) if value == "null" => Defines.Null - case Some(value) if value == "boolean" => Defines.Boolean - case Some(value) if value == "any" => Defines.Any - case Some(value) if ImportMatcher.matcher(value).matches() => importToModule(value) - case Some(other) => - (TypeReplacements ++ ArrayReplacements).foldLeft(other) { case (typeStr, (m, r)) => - typeStr.replace(m, r) - } - case None => Defines.Any - - private def importToModule(value: String): String = - val matcher = ImportMatcher.matcher(value) - this.rootTypeDecl.headOption match - case Some(typeDecl) => typeDecl.fullName - case None if matcher.matches() => - matcher.group(2).stripSuffix(".js").concat(".js::program") - case None => value - - protected def typeFor(node: BabelNodeInfo): String = - val tpe = Seq(TypeAnnotationKey, ReturnTypeKey).find(hasKey(node.json, _)) match - case Some(key) => typeForTypeAnnotation(createBabelNodeInfo(node.json(key))) - case None => typeFromTypeMap(node) - registerType(tpe, tpe) - tpe - - protected def typeHintForThisExpression(node: Option[BabelNodeInfo] = None): Seq[String] = - dynamicInstanceTypeStack.headOption match - case Some(tpe) => Seq(tpe) - case None if node.isDefined => - typeFor(node.get) match - case t if t != Defines.Any && t != "this" => Seq(t) - case _ => rootTypeDecl.map(_.fullName).toSeq - case None => rootTypeDecl.map(_.fullName).toSeq + this: AstCreator => + + private val TypeAnnotationKey = "typeAnnotation" + private val ReturnTypeKey = "returnType" + private val ImportMatcher = Pattern.compile("(typeof )?import\\([\"'](.*)[\"']\\)") + + private val ArrayReplacements = Map( + "any[]" -> s"${Defines.Any}[]", + "unknown[]" -> s"${Defines.Unknown}[]", + "number[]" -> s"${Defines.Number}[]", + "string[]" -> s"${Defines.String}[]", + "boolean[]" -> s"${Defines.Boolean}[]" + ) + + private val TypeReplacements = Map( + " any" -> s" ${Defines.Any}", + " unknown" -> s" ${Defines.Unknown}", + " number" -> s" ${Defines.Number}", + " null" -> s" ${Defines.Null}", + " string" -> s" ${Defines.String}", + " boolean" -> s" ${Defines.Boolean}", + " bigint" -> s" ${Defines.BigInt}", + "{}" -> Defines.Object, + "typeof " -> "" + ) + + protected def isPlainTypeAlias(alias: BabelNodeInfo): Boolean = + if hasKey(alias.json, "right") then + createBabelNodeInfo(alias.json("right")).node.toString == TSTypeReference.toString + else + createBabelNodeInfo( + alias.json("typeAnnotation") + ).node.toString == TSTypeReference.toString + + private def typeForFlowType(flowType: BabelNodeInfo): String = flowType.node match + case BooleanTypeAnnotation => Defines.Boolean + case NumberTypeAnnotation => Defines.Number + case ObjectTypeAnnotation => Defines.Object + case StringTypeAnnotation => Defines.String + case SymbolTypeAnnotation => Defines.Symbol + case NumberLiteralTypeAnnotation => code(flowType.json) + case ArrayTypeAnnotation => code(flowType.json) + case BooleanLiteralTypeAnnotation => code(flowType.json) + case NullLiteralTypeAnnotation => code(flowType.json) + case StringLiteralTypeAnnotation => code(flowType.json) + case GenericTypeAnnotation => code(flowType.json("id")) + case ThisTypeAnnotation => + typeHintForThisExpression(Option(flowType)).headOption.getOrElse(Defines.Any); + case NullableTypeAnnotation => + typeForTypeAnnotation(createBabelNodeInfo(flowType.json(TypeAnnotationKey))) + case _ => Defines.Any + + private def typeForTsType(tsType: BabelNodeInfo): String = tsType.node match + case TSBooleanKeyword => Defines.Boolean + case TSBigIntKeyword => Defines.Number + case TSNullKeyword => Defines.Null + case TSNumberKeyword => Defines.Number + case TSObjectKeyword => Defines.Object + case TSStringKeyword => Defines.String + case TSSymbolKeyword => Defines.Symbol + case TSUnknownKeyword => Defines.Unknown + case TSVoidKeyword => Defines.Void + case TSUndefinedKeyword => Defines.Undefined + case TSNeverKeyword => Defines.Never + case TSIntrinsicKeyword => code(tsType.json) + case TSTypeReference => code(tsType.json) + case TSArrayType => code(tsType.json) + case TSThisType => + typeHintForThisExpression(Option(tsType)).headOption.getOrElse(Defines.Any) + case TSOptionalType => + typeForTypeAnnotation(createBabelNodeInfo(tsType.json(TypeAnnotationKey))) + case TSRestType => + typeForTypeAnnotation(createBabelNodeInfo(tsType.json(TypeAnnotationKey))) + case TSParenthesizedType => + typeForTypeAnnotation(createBabelNodeInfo(tsType.json(TypeAnnotationKey))) + case _ => Defines.Any + + private def typeForTypeAnnotation(typeAnnotation: BabelNodeInfo): String = + typeAnnotation.node match + case TypeAnnotation => + typeForFlowType(createBabelNodeInfo(typeAnnotation.json(TypeAnnotationKey))) + case TSTypeAnnotation => + typeForTsType(createBabelNodeInfo(typeAnnotation.json(TypeAnnotationKey))) + case _: FlowType => typeForFlowType(createBabelNodeInfo(typeAnnotation.json)) + case _: TSType => typeForTsType(createBabelNodeInfo(typeAnnotation.json)) + case _ => Defines.Any + + private def isStringType(tpe: String): Boolean = + tpe.startsWith("\"") && tpe.endsWith("\"") + + private def isNumberType(tpe: String): Boolean = + tpe.toDoubleOption.isDefined + + private def typeFromTypeMap(node: BabelNodeInfo): String = + pos(node.json).flatMap(parserResult.typeMap.get) match + case Some(value) if value.isEmpty => Defines.String + case Some(value) if value == "string" => Defines.String + case Some(value) if isStringType(value) => Defines.String + case Some(value) if value == "number" => Defines.Number + case Some(value) if isNumberType(value) => Defines.Number + case Some(value) if value == "null" => Defines.Null + case Some(value) if value == "boolean" => Defines.Boolean + case Some(value) if value == "any" => Defines.Any + case Some(value) if ImportMatcher.matcher(value).matches() => importToModule(value) + case Some(other) => + (TypeReplacements ++ ArrayReplacements).foldLeft(other) { case (typeStr, (m, r)) => + typeStr.replace(m, r) + } + case None => Defines.Any + + private def importToModule(value: String): String = + val matcher = ImportMatcher.matcher(value) + this.rootTypeDecl.headOption match + case Some(typeDecl) => typeDecl.fullName + case None if matcher.matches() => + matcher.group(2).stripSuffix(".js").concat(".js::program") + case None => value + + protected def typeFor(node: BabelNodeInfo): String = + val tpe = Seq(TypeAnnotationKey, ReturnTypeKey).find(hasKey(node.json, _)) match + case Some(key) => typeForTypeAnnotation(createBabelNodeInfo(node.json(key))) + case None => typeFromTypeMap(node) + registerType(tpe, tpe) + tpe + + protected def typeHintForThisExpression(node: Option[BabelNodeInfo] = None): Seq[String] = + dynamicInstanceTypeStack.headOption match + case Some(tpe) => Seq(tpe) + case None if node.isDefined => + typeFor(node.get) match + case t if t != Defines.Any && t != "this" => Seq(t) + case _ => rootTypeDecl.map(_.fullName).toSeq + case None => rootTypeDecl.map(_.fullName).toSeq end TypeHelper diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/datastructures/PendingReference.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/datastructures/PendingReference.scala index 5e85539d..8139da93 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/datastructures/PendingReference.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/datastructures/PendingReference.scala @@ -8,14 +8,14 @@ case class PendingReference( stack: Option[ScopeElement] ): - def tryResolve(): Option[ResolvedReference] = - var foundVariableOption = Option.empty[NewNode] - val stackIterator = new ScopeElementIterator(stack) + def tryResolve(): Option[ResolvedReference] = + var foundVariableOption = Option.empty[NewNode] + val stackIterator = new ScopeElementIterator(stack) - while stackIterator.hasNext && foundVariableOption.isEmpty do - val scopeElement = stackIterator.next() - foundVariableOption = scopeElement.nameToVariableNode.get(variableName) + while stackIterator.hasNext && foundVariableOption.isEmpty do + val scopeElement = stackIterator.next() + foundVariableOption = scopeElement.nameToVariableNode.get(variableName) - foundVariableOption.map { variableNodeId => - ResolvedReference(variableNodeId, this) - } + foundVariableOption.map { variableNodeId => + ResolvedReference(variableNodeId, this) + } diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/datastructures/Scope.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/datastructures/Scope.scala index c0958c7b..0c6094ae 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/datastructures/Scope.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/datastructures/Scope.scala @@ -8,101 +8,101 @@ import scala.collection.mutable */ class Scope: - private val pendingReferences: mutable.Buffer[PendingReference] = - mutable.ListBuffer.empty[PendingReference] - - private var stack = Option.empty[ScopeElement] - - def getScopeHead: Option[ScopeElement] = stack - - def isEmpty: Boolean = stack.isEmpty - - def pushNewMethodScope( - methodFullName: String, - name: String, - scopeNode: NewNode, - capturingRefId: Option[NewNode] - ): Unit = - stack = Option(new MethodScopeElement( - methodFullName, - capturingRefId, - name, - scopeNode, - surroundingScope = stack - )) - - def pushNewBlockScope(scopeNode: NewNode): Unit = - peek match - case Some(stackTop) => - stack = Option(new BlockScopeElement( - stackTop.subScopeCounter.toString, - scopeNode, - surroundingScope = stack - )) - stackTop.subScopeCounter += 1 - case None => - stack = Option(new BlockScopeElement("0", scopeNode, surroundingScope = stack)) - - private def peek: Option[ScopeElement] = - stack - - def popScope(): Unit = - stack = stack.get.surroundingScope - - def addVariable(variableName: String, variableNode: NewNode, scopeType: ScopeType): Unit = - addVariable(stack, variableName, variableNode, scopeType) - - def addVariableReference(variableName: String, referenceNode: NewNode): Unit = - pendingReferences.prepend(PendingReference(variableName, referenceNode, stack)) - - def resolve(unresolvedHandler: (NewNode, String) => (NewNode, ScopeType)) - : Iterator[ResolvedReference] = - pendingReferences.iterator.map { pendingReference => - val resolvedReferenceOption = pendingReference.tryResolve() - - resolvedReferenceOption.getOrElse { - val methodScopeNode = Scope.getEnclosingMethodScopeNode(pendingReference.stack) - val (newVariableNode, scopeType) = - unresolvedHandler(methodScopeNode, pendingReference.variableName) - addVariable( - pendingReference.stack, - pendingReference.variableName, - newVariableNode, - scopeType - ) - pendingReference.tryResolve().get - } + private val pendingReferences: mutable.Buffer[PendingReference] = + mutable.ListBuffer.empty[PendingReference] + + private var stack = Option.empty[ScopeElement] + + def getScopeHead: Option[ScopeElement] = stack + + def isEmpty: Boolean = stack.isEmpty + + def pushNewMethodScope( + methodFullName: String, + name: String, + scopeNode: NewNode, + capturingRefId: Option[NewNode] + ): Unit = + stack = Option(new MethodScopeElement( + methodFullName, + capturingRefId, + name, + scopeNode, + surroundingScope = stack + )) + + def pushNewBlockScope(scopeNode: NewNode): Unit = + peek match + case Some(stackTop) => + stack = Option(new BlockScopeElement( + stackTop.subScopeCounter.toString, + scopeNode, + surroundingScope = stack + )) + stackTop.subScopeCounter += 1 + case None => + stack = Option(new BlockScopeElement("0", scopeNode, surroundingScope = stack)) + + private def peek: Option[ScopeElement] = + stack + + def popScope(): Unit = + stack = stack.get.surroundingScope + + def addVariable(variableName: String, variableNode: NewNode, scopeType: ScopeType): Unit = + addVariable(stack, variableName, variableNode, scopeType) + + def addVariableReference(variableName: String, referenceNode: NewNode): Unit = + pendingReferences.prepend(PendingReference(variableName, referenceNode, stack)) + + def resolve(unresolvedHandler: (NewNode, String) => (NewNode, ScopeType)) + : Iterator[ResolvedReference] = + pendingReferences.iterator.map { pendingReference => + val resolvedReferenceOption = pendingReference.tryResolve() + + resolvedReferenceOption.getOrElse { + val methodScopeNode = Scope.getEnclosingMethodScopeNode(pendingReference.stack) + val (newVariableNode, scopeType) = + unresolvedHandler(methodScopeNode, pendingReference.variableName) + addVariable( + pendingReference.stack, + pendingReference.variableName, + newVariableNode, + scopeType + ) + pendingReference.tryResolve().get } - - private def addVariable( - stack: Option[ScopeElement], - variableName: String, - variableNode: NewNode, - scopeType: ScopeType - ): Unit = - val scopeToAddTo = scopeType match - case MethodScope => Scope.getEnclosingMethodScopeElement(stack) - case _ => stack.get - scopeToAddTo.addVariable(variableName, variableNode) + } + + private def addVariable( + stack: Option[ScopeElement], + variableName: String, + variableNode: NewNode, + scopeType: ScopeType + ): Unit = + val scopeToAddTo = scopeType match + case MethodScope => Scope.getEnclosingMethodScopeElement(stack) + case _ => stack.get + scopeToAddTo.addVariable(variableName, variableNode) end Scope object Scope: - private def getEnclosingMethodScopeNode(scopeHead: Option[ScopeElement]): NewNode = - getEnclosingMethodScopeElement(scopeHead).scopeNode + private def getEnclosingMethodScopeNode(scopeHead: Option[ScopeElement]): NewNode = + getEnclosingMethodScopeElement(scopeHead).scopeNode - def getEnclosingMethodScopeElement(scopeHead: Option[ScopeElement]): MethodScopeElement = - // There are no references outside of methods. Meaning we always find a MethodScope here. - new ScopeElementIterator(scopeHead) - .collectFirst { case methodScopeElement: MethodScopeElement => methodScopeElement } - .getOrElse(throw new RuntimeException("Cannot find method scope.")) + def getEnclosingMethodScopeElement(scopeHead: Option[ScopeElement]): MethodScopeElement = + // There are no references outside of methods. Meaning we always find a MethodScope here. + new ScopeElementIterator(scopeHead) + .collectFirst { case methodScopeElement: MethodScopeElement => methodScopeElement } + .getOrElse(throw new RuntimeException("Cannot find method scope.")) class ScopeElementIterator(start: Option[ScopeElement]) extends Iterator[ScopeElement]: - private var currentScopeElement = start + private var currentScopeElement = start - override def hasNext: Boolean = - currentScopeElement.isDefined + override def hasNext: Boolean = + currentScopeElement.isDefined - override def next(): ScopeElement = - val result = currentScopeElement.get - currentScopeElement = result.surroundingScope - result + override def next(): ScopeElement = + val result = currentScopeElement.get + currentScopeElement = result.surroundingScope + result diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/datastructures/ScopeElement.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/datastructures/ScopeElement.scala index e1542cb2..a2ac5efa 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/datastructures/ScopeElement.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/datastructures/ScopeElement.scala @@ -11,11 +11,11 @@ abstract class ScopeElement( val scopeNode: NewNode, val surroundingScope: Option[ScopeElement] ): - var subScopeCounter: Int = 0 - val nameToVariableNode: mutable.Map[String, NewNode] = mutable.HashMap.empty + var subScopeCounter: Int = 0 + val nameToVariableNode: mutable.Map[String, NewNode] = mutable.HashMap.empty - def addVariable(variableName: String, variableNode: NewNode): Unit = - nameToVariableNode(variableName) = variableNode + def addVariable(variableName: String, variableNode: NewNode): Unit = + nameToVariableNode(variableName) = variableNode class MethodScopeElement( val methodFullName: String, diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/parser/BabelAst.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/parser/BabelAst.scala index c7303086..ff9e7c5d 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/parser/BabelAst.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/parser/BabelAst.scala @@ -2,278 +2,279 @@ package io.appthreat.jssrc2cpg.parser object BabelAst: - private val QualifiedClassName: String = BabelAst.getClass.getName + private val QualifiedClassName: String = BabelAst.getClass.getName - def fromString(nodeName: String): BabelNode = - val clazz = Class.forName(s"$QualifiedClassName$nodeName$$") - clazz.getField("MODULE$").get(clazz).asInstanceOf[BabelNode] + def fromString(nodeName: String): BabelNode = + val clazz = Class.forName(s"$QualifiedClassName$nodeName$$") + clazz.getField("MODULE$").get(clazz).asInstanceOf[BabelNode] - // extracted from: - // https://github.com/babel/babel/blob/main/packages/babel-types/src/ast-types/generated/index.ts + // extracted from: + // https://github.com/babel/babel/blob/main/packages/babel-types/src/ast-types/generated/index.ts - sealed trait BabelNode: - override def toString: String = this.getClass.getSimpleName.stripSuffix("$") + sealed trait BabelNode: + override def toString: String = this.getClass.getSimpleName.stripSuffix("$") - sealed trait FlowType extends BabelNode + sealed trait FlowType extends BabelNode - sealed trait TSType extends BabelNode + sealed trait TSType extends BabelNode - sealed trait Expression extends BabelNode + sealed trait Expression extends BabelNode - sealed trait FunctionLike extends BabelNode + sealed trait FunctionLike extends BabelNode - object AnyTypeAnnotation extends FlowType - object ArgumentPlaceholder extends BabelNode - object ArrayExpression extends BabelNode - object ArrayPattern extends BabelNode - object ArrayTypeAnnotation extends FlowType - object ArrowFunctionExpression extends FunctionLike - object AssignmentExpression extends BabelNode - object AssignmentPattern extends BabelNode - object AwaitExpression extends BabelNode - object BigIntLiteral extends BabelNode - object BinaryExpression extends BabelNode - object BindExpression extends BabelNode - object BlockStatement extends BabelNode - object BooleanLiteral extends BabelNode - object BooleanLiteralTypeAnnotation extends FlowType - object BooleanTypeAnnotation extends FlowType - object BreakStatement extends BabelNode - object CallExpression extends Expression - object CatchClause extends BabelNode - object ClassAccessorProperty extends BabelNode - object ClassBody extends BabelNode - object ClassDeclaration extends BabelNode - object ClassExpression extends BabelNode - object ClassImplements extends BabelNode - object ClassMethod extends BabelNode - object ClassPrivateMethod extends BabelNode - object ClassPrivateProperty extends BabelNode - object ClassProperty extends BabelNode - object ConditionalExpression extends BabelNode - object ContinueStatement extends BabelNode - object DebuggerStatement extends BabelNode - object DecimalLiteral extends BabelNode - object DeclareClass extends BabelNode - object DeclareExportAllDeclaration extends BabelNode - object DeclareExportDeclaration extends BabelNode - object DeclareFunction extends BabelNode - object DeclareInterface extends BabelNode - object DeclareModule extends BabelNode - object DeclareModuleExports extends BabelNode - object DeclareOpaqueType extends BabelNode - object DeclareTypeAlias extends BabelNode - object DeclareVariable extends BabelNode - object DeclaredPredicate extends BabelNode - object Decorator extends BabelNode - object Directive extends BabelNode - object DirectiveLiteral extends BabelNode - object DoExpression extends BabelNode - object DoWhileStatement extends BabelNode - object EmptyStatement extends BabelNode - object EmptyTypeAnnotation extends FlowType - object EnumBooleanBody extends BabelNode - object EnumBooleanMember extends BabelNode - object EnumDeclaration extends BabelNode - object EnumDefaultedMember extends BabelNode - object EnumNumberBody extends BabelNode - object EnumNumberMember extends BabelNode - object EnumStringBody extends BabelNode - object EnumStringMember extends BabelNode - object EnumSymbolBody extends BabelNode - object ExistsTypeAnnotation extends FlowType - object ExportAllDeclaration extends BabelNode - object ExportDefaultDeclaration extends BabelNode - object ExportDefaultSpecifier extends BabelNode - object ExportNamedDeclaration extends BabelNode - object ExportNamespaceSpecifier extends BabelNode - object ExportSpecifier extends BabelNode - object ExpressionStatement extends BabelNode - object File extends BabelNode - object ForInStatement extends BabelNode - object ForOfStatement extends BabelNode - object ForStatement extends BabelNode - object FunctionDeclaration extends FunctionLike - object FunctionExpression extends FunctionLike - object FunctionTypeAnnotation extends FlowType - object FunctionTypeParam extends BabelNode - object GenericTypeAnnotation extends FlowType - object Identifier extends BabelNode - object IfStatement extends BabelNode - object Import extends BabelNode - object ImportAttribute extends BabelNode - object ImportDeclaration extends BabelNode - object ImportDefaultSpecifier extends BabelNode - object ImportNamespaceSpecifier extends BabelNode - object ImportSpecifier extends BabelNode - object IndexedAccessType extends FlowType - object InferredPredicate extends BabelNode - object InterfaceDeclaration extends BabelNode - object InterfaceExtends extends BabelNode - object InterfaceTypeAnnotation extends FlowType - object InterpreterDirective extends BabelNode - object IntersectionTypeAnnotation extends FlowType - object JSXAttribute extends BabelNode - object JSXClosingElement extends BabelNode - object JSXClosingFragment extends BabelNode - object JSXElement extends BabelNode - object JSXEmptyExpression extends BabelNode - object JSXExpressionContainer extends BabelNode - object JSXFragment extends BabelNode - object JSXIdentifier extends BabelNode - object JSXMemberExpression extends Expression - object JSXNamespacedName extends BabelNode - object JSXOpeningElement extends BabelNode - object JSXOpeningFragment extends BabelNode - object JSXSpreadAttribute extends BabelNode - object JSXSpreadChild extends BabelNode - object JSXText extends BabelNode - object LabeledStatement extends BabelNode - object LogicalExpression extends BabelNode - object MemberExpression extends Expression - object MetaProperty extends BabelNode - object MixedTypeAnnotation extends FlowType - object ModuleExpression extends BabelNode - object NewExpression extends Expression - object Noop extends BabelNode - object NullLiteral extends BabelNode - object NullLiteralTypeAnnotation extends FlowType - object NullableTypeAnnotation extends FlowType - object NumberLiteral extends BabelNode - object NumberLiteralTypeAnnotation extends FlowType - object NumberTypeAnnotation extends FlowType - object NumericLiteral extends BabelNode - object ObjectExpression extends BabelNode - object ObjectMethod extends BabelNode - object ObjectPattern extends BabelNode - object ObjectProperty extends BabelNode - object ObjectTypeAnnotation extends FlowType - object ObjectTypeCallProperty extends BabelNode - object ObjectTypeIndexer extends BabelNode - object ObjectTypeInternalSlot extends BabelNode - object ObjectTypeProperty extends BabelNode - object ObjectTypeSpreadProperty extends BabelNode - object OpaqueType extends BabelNode - object OptionalCallExpression extends Expression - object OptionalIndexedAccessType extends FlowType - object OptionalMemberExpression extends Expression - object ParenthesizedExpression extends BabelNode - object PipelineBareFunction extends BabelNode - object PipelinePrimaryTopicReference extends BabelNode - object PipelineTopicExpression extends Expression - object Placeholder extends BabelNode - object PrivateName extends BabelNode - object Program extends BabelNode - object QualifiedTypeIdentifier extends BabelNode - object RecordExpression extends Expression - object RegExpLiteral extends BabelNode - object RegexLiteral extends BabelNode - object RestElement extends BabelNode - object RestProperty extends BabelNode - object ReturnStatement extends BabelNode - object SequenceExpression extends BabelNode - object SpreadElement extends BabelNode - object SpreadProperty extends BabelNode - object StaticBlock extends BabelNode - object StringLiteral extends BabelNode - object StringLiteralTypeAnnotation extends FlowType - object StringTypeAnnotation extends FlowType - object Super extends BabelNode - object SwitchCase extends BabelNode - object SwitchStatement extends BabelNode - object SymbolTypeAnnotation extends FlowType - object TSAnyKeyword extends TSType - object TSArrayType extends TSType - object TSAsExpression extends Expression - object TSBigIntKeyword extends TSType - object TSBooleanKeyword extends TSType - object TSCallSignatureDeclaration extends BabelNode - object TSConditionalType extends TSType - object TSConstructSignatureDeclaration extends BabelNode - object TSConstructorType extends TSType - object TSDeclareFunction extends BabelNode - object TSDeclareMethod extends BabelNode - object TSEnumDeclaration extends BabelNode - object TSEnumMember extends BabelNode - object TSExportAssignment extends BabelNode - object TSExpressionWithTypeArguments extends TSType - object TSExternalModuleReference extends BabelNode - object TSFunctionType extends TSType - object TSImportEqualsDeclaration extends BabelNode - object TSImportType extends TSType - object TSIndexSignature extends BabelNode - object TSIndexedAccessType extends TSType - object TSInferType extends TSType - object TSInterfaceBody extends BabelNode - object TSInterfaceDeclaration extends BabelNode - object TSIntersectionType extends TSType - object TSIntrinsicKeyword extends TSType - object TSLiteralType extends TSType - object TSMappedType extends TSType - object TSMethodSignature extends BabelNode - object TSModuleBlock extends BabelNode - object TSModuleDeclaration extends BabelNode - object TSNamedTupleMember extends BabelNode - object TSNamespaceExportDeclaration extends BabelNode - object TSNeverKeyword extends TSType - object TSNonNullExpression extends Expression - object TSNullKeyword extends TSType - object TSNumberKeyword extends TSType - object TSObjectKeyword extends TSType - object TSOptionalType extends TSType - object TSParameterProperty extends Expression - object TSParenthesizedType extends TSType - object TSPropertySignature extends BabelNode - object TSQualifiedName extends BabelNode - object TSRestType extends TSType - object TSSatisfiesExpression extends Expression - object TSStringKeyword extends TSType - object TSSymbolKeyword extends TSType - object TSThisType extends TSType - object TSTupleType extends TSType - object TSTypeAliasDeclaration extends BabelNode - object TSTypeAnnotation extends FlowType - object TSTypeAssertion extends BabelNode - object TSTypeExpression extends TSType - object TSTypeLiteral extends TSType - object TSTypeOperator extends TSType - object TSTypeParameter extends TSType - object TSTypeParameterDeclaration extends BabelNode - object TSTypeParameterInstantiation extends BabelNode - object TSTypePredicate extends TSType - object TSTypeQuery extends TSType - object TSTypeReference extends TSType - object TSUndefinedKeyword extends TSType - object TSUnionType extends TSType - object TSUnknownKeyword extends TSType - object TSVoidKeyword extends TSType - object TaggedTemplateExpression extends BabelNode - object TemplateElement extends BabelNode - object TemplateLiteral extends BabelNode - object ThisExpression extends Expression - object ThisTypeAnnotation extends FlowType - object ThrowStatement extends BabelNode - object TopicReference extends BabelNode - object TryStatement extends BabelNode - object TupleExpression extends BabelNode - object TupleTypeAnnotation extends FlowType - object TypeAlias extends BabelNode - object TypeAnnotation extends FlowType - object TypeCastExpression extends BabelNode - object TSTypeCastExpression extends BabelNode - object TypeParameter extends BabelNode - object TypeParameterDeclaration extends BabelNode - object TypeParameterInstantiation extends BabelNode - object TypeofTypeAnnotation extends FlowType - object UnaryExpression extends BabelNode - object UnionTypeAnnotation extends FlowType - object UpdateExpression extends Expression - object V8IntrinsicIdentifier extends BabelNode - object VariableDeclaration extends BabelNode - object VariableDeclarator extends BabelNode - object Variance extends BabelNode - object VoidTypeAnnotation extends FlowType - object WhileStatement extends BabelNode - object WithStatement extends BabelNode - object YieldExpression extends BabelNode + object AnyTypeAnnotation extends FlowType + object ArgumentPlaceholder extends BabelNode + object ArrayExpression extends BabelNode + object ArrayPattern extends BabelNode + object ArrayTypeAnnotation extends FlowType + object ArrowFunctionExpression extends FunctionLike + object AssignmentExpression extends BabelNode + object AssignmentPattern extends BabelNode + object AwaitExpression extends BabelNode + object BigIntLiteral extends BabelNode + object BinaryExpression extends BabelNode + object BindExpression extends BabelNode + object BlockStatement extends BabelNode + object BooleanLiteral extends BabelNode + object BooleanLiteralTypeAnnotation extends FlowType + object BooleanTypeAnnotation extends FlowType + object BreakStatement extends BabelNode + object CallExpression extends Expression + object CatchClause extends BabelNode + object ClassAccessorProperty extends BabelNode + object ClassBody extends BabelNode + object ClassDeclaration extends BabelNode + object ClassExpression extends BabelNode + object ClassImplements extends BabelNode + object ClassMethod extends BabelNode + object ClassPrivateMethod extends BabelNode + object ClassPrivateProperty extends BabelNode + object ClassProperty extends BabelNode + object ConditionalExpression extends BabelNode + object ContinueStatement extends BabelNode + object DebuggerStatement extends BabelNode + object DecimalLiteral extends BabelNode + object DeclareClass extends BabelNode + object DeclareExportAllDeclaration extends BabelNode + object DeclareExportDeclaration extends BabelNode + object DeclareFunction extends BabelNode + object DeclareInterface extends BabelNode + object DeclareModule extends BabelNode + object DeclareModuleExports extends BabelNode + object DeclareOpaqueType extends BabelNode + object DeclareTypeAlias extends BabelNode + object DeclareVariable extends BabelNode + object DeclaredPredicate extends BabelNode + object Decorator extends BabelNode + object Directive extends BabelNode + object DirectiveLiteral extends BabelNode + object DoExpression extends BabelNode + object DoWhileStatement extends BabelNode + object EmptyStatement extends BabelNode + object EmptyTypeAnnotation extends FlowType + object EnumBooleanBody extends BabelNode + object EnumBooleanMember extends BabelNode + object EnumDeclaration extends BabelNode + object EnumDefaultedMember extends BabelNode + object EnumNumberBody extends BabelNode + object EnumNumberMember extends BabelNode + object EnumStringBody extends BabelNode + object EnumStringMember extends BabelNode + object EnumSymbolBody extends BabelNode + object ExistsTypeAnnotation extends FlowType + object ExportAllDeclaration extends BabelNode + object ExportDefaultDeclaration extends BabelNode + object ExportDefaultSpecifier extends BabelNode + object ExportNamedDeclaration extends BabelNode + object ExportNamespaceSpecifier extends BabelNode + object ExportSpecifier extends BabelNode + object ExpressionStatement extends BabelNode + object File extends BabelNode + object ForInStatement extends BabelNode + object ForOfStatement extends BabelNode + object ForStatement extends BabelNode + object FunctionDeclaration extends FunctionLike + object FunctionExpression extends FunctionLike + object FunctionTypeAnnotation extends FlowType + object FunctionTypeParam extends BabelNode + object GenericTypeAnnotation extends FlowType + object Identifier extends BabelNode + object IfStatement extends BabelNode + object Import extends BabelNode + object ImportAttribute extends BabelNode + object ImportDeclaration extends BabelNode + object ImportDefaultSpecifier extends BabelNode + object ImportNamespaceSpecifier extends BabelNode + object ImportSpecifier extends BabelNode + object IndexedAccessType extends FlowType + object InferredPredicate extends BabelNode + object InterfaceDeclaration extends BabelNode + object InterfaceExtends extends BabelNode + object InterfaceTypeAnnotation extends FlowType + object InterpreterDirective extends BabelNode + object IntersectionTypeAnnotation extends FlowType + object JSXAttribute extends BabelNode + object JSXClosingElement extends BabelNode + object JSXClosingFragment extends BabelNode + object JSXElement extends BabelNode + object JSXEmptyExpression extends BabelNode + object JSXExpressionContainer extends BabelNode + object JSXFragment extends BabelNode + object JSXIdentifier extends BabelNode + object JSXMemberExpression extends Expression + object JSXNamespacedName extends BabelNode + object JSXOpeningElement extends BabelNode + object JSXOpeningFragment extends BabelNode + object JSXSpreadAttribute extends BabelNode + object JSXSpreadChild extends BabelNode + object JSXText extends BabelNode + object LabeledStatement extends BabelNode + object LogicalExpression extends BabelNode + object MemberExpression extends Expression + object MetaProperty extends BabelNode + object MixedTypeAnnotation extends FlowType + object ModuleExpression extends BabelNode + object NewExpression extends Expression + object Noop extends BabelNode + object NullLiteral extends BabelNode + object NullLiteralTypeAnnotation extends FlowType + object NullableTypeAnnotation extends FlowType + object NumberLiteral extends BabelNode + object NumberLiteralTypeAnnotation extends FlowType + object NumberTypeAnnotation extends FlowType + object NumericLiteral extends BabelNode + object ObjectExpression extends BabelNode + object ObjectMethod extends BabelNode + object ObjectPattern extends BabelNode + object ObjectProperty extends BabelNode + object ObjectTypeAnnotation extends FlowType + object ObjectTypeCallProperty extends BabelNode + object ObjectTypeIndexer extends BabelNode + object ObjectTypeInternalSlot extends BabelNode + object ObjectTypeProperty extends BabelNode + object ObjectTypeSpreadProperty extends BabelNode + object OpaqueType extends BabelNode + object OptionalCallExpression extends Expression + object OptionalIndexedAccessType extends FlowType + object OptionalMemberExpression extends Expression + object ParenthesizedExpression extends BabelNode + object PipelineBareFunction extends BabelNode + object PipelinePrimaryTopicReference extends BabelNode + object PipelineTopicExpression extends Expression + object Placeholder extends BabelNode + object PrivateName extends BabelNode + object Program extends BabelNode + object QualifiedTypeIdentifier extends BabelNode + object RecordExpression extends Expression + object RegExpLiteral extends BabelNode + object RegexLiteral extends BabelNode + object RestElement extends BabelNode + object RestProperty extends BabelNode + object ReturnStatement extends BabelNode + object SequenceExpression extends BabelNode + object SpreadElement extends BabelNode + object SpreadProperty extends BabelNode + object StaticBlock extends BabelNode + object StringLiteral extends BabelNode + object StringLiteralTypeAnnotation extends FlowType + object StringTypeAnnotation extends FlowType + object Super extends BabelNode + object SwitchCase extends BabelNode + object SwitchStatement extends BabelNode + object SymbolTypeAnnotation extends FlowType + object TSAnyKeyword extends TSType + object TSArrayType extends TSType + object TSAsExpression extends Expression + object TSBigIntKeyword extends TSType + object TSBooleanKeyword extends TSType + object TSCallSignatureDeclaration extends BabelNode + object TSConditionalType extends TSType + object TSConstructSignatureDeclaration extends BabelNode + object TSConstructorType extends TSType + object TSDeclareFunction extends BabelNode + object TSDeclareMethod extends BabelNode + object TSEnumDeclaration extends BabelNode + object TSEnumMember extends BabelNode + object TSExportAssignment extends BabelNode + object TSExpressionWithTypeArguments extends TSType + object TSExternalModuleReference extends BabelNode + object TSFunctionType extends TSType + object TSImportEqualsDeclaration extends BabelNode + object TSImportType extends TSType + object TSIndexSignature extends BabelNode + object TSIndexedAccessType extends TSType + object TSInferType extends TSType + object TSInstantiationExpression extends Expression + object TSInterfaceBody extends BabelNode + object TSInterfaceDeclaration extends BabelNode + object TSIntersectionType extends TSType + object TSIntrinsicKeyword extends TSType + object TSLiteralType extends TSType + object TSMappedType extends TSType + object TSMethodSignature extends BabelNode + object TSModuleBlock extends BabelNode + object TSModuleDeclaration extends BabelNode + object TSNamedTupleMember extends BabelNode + object TSNamespaceExportDeclaration extends BabelNode + object TSNeverKeyword extends TSType + object TSNonNullExpression extends Expression + object TSNullKeyword extends TSType + object TSNumberKeyword extends TSType + object TSObjectKeyword extends TSType + object TSOptionalType extends TSType + object TSParameterProperty extends Expression + object TSParenthesizedType extends TSType + object TSPropertySignature extends BabelNode + object TSQualifiedName extends BabelNode + object TSRestType extends TSType + object TSSatisfiesExpression extends Expression + object TSStringKeyword extends TSType + object TSSymbolKeyword extends TSType + object TSThisType extends TSType + object TSTupleType extends TSType + object TSTypeAliasDeclaration extends BabelNode + object TSTypeAnnotation extends FlowType + object TSTypeAssertion extends BabelNode + object TSTypeExpression extends TSType + object TSTypeLiteral extends TSType + object TSTypeOperator extends TSType + object TSTypeParameter extends TSType + object TSTypeParameterDeclaration extends BabelNode + object TSTypeParameterInstantiation extends BabelNode + object TSTypePredicate extends TSType + object TSTypeQuery extends TSType + object TSTypeReference extends TSType + object TSUndefinedKeyword extends TSType + object TSUnionType extends TSType + object TSUnknownKeyword extends TSType + object TSVoidKeyword extends TSType + object TaggedTemplateExpression extends BabelNode + object TemplateElement extends BabelNode + object TemplateLiteral extends BabelNode + object ThisExpression extends Expression + object ThisTypeAnnotation extends FlowType + object ThrowStatement extends BabelNode + object TopicReference extends BabelNode + object TryStatement extends BabelNode + object TupleExpression extends BabelNode + object TupleTypeAnnotation extends FlowType + object TypeAlias extends BabelNode + object TypeAnnotation extends FlowType + object TypeCastExpression extends BabelNode + object TSTypeCastExpression extends BabelNode + object TypeParameter extends BabelNode + object TypeParameterDeclaration extends BabelNode + object TypeParameterInstantiation extends BabelNode + object TypeofTypeAnnotation extends FlowType + object UnaryExpression extends BabelNode + object UnionTypeAnnotation extends FlowType + object UpdateExpression extends Expression + object V8IntrinsicIdentifier extends BabelNode + object VariableDeclaration extends BabelNode + object VariableDeclarator extends BabelNode + object Variance extends BabelNode + object VoidTypeAnnotation extends FlowType + object WhileStatement extends BabelNode + object WithStatement extends BabelNode + object YieldExpression extends BabelNode end BabelAst diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/parser/BabelJsonParser.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/parser/BabelJsonParser.scala index 5c25d36d..93067daf 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/parser/BabelJsonParser.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/parser/BabelJsonParser.scala @@ -8,27 +8,27 @@ import java.nio.file.Paths object BabelJsonParser: - case class ParseResult( - filename: String, - fullPath: String, - json: Value, - fileContent: String, - typeMap: Map[Int, String] - ) + case class ParseResult( + filename: String, + fullPath: String, + json: Value, + fileContent: String, + typeMap: Map[Int, String] + ) - def readFile(rootPath: Path, file: Path): ParseResult = - val typeMapPath = Paths.get(file.toString.replace(".json", ".typemap")) - val typeMap = if typeMapPath.toFile.exists() then - val typeMapJsonContent = IOUtils.readEntireFile(typeMapPath) - val typeMapJson = ujson.read(typeMapJsonContent) - typeMapJson.obj.map { case (k, v) => k.toInt -> v.str }.toMap - else - Map.empty[Int, String] + def readFile(rootPath: Path, file: Path): ParseResult = + val typeMapPath = Paths.get(file.toString.replace(".json", ".typemap")) + val typeMap = if typeMapPath.toFile.exists() then + val typeMapJsonContent = IOUtils.readEntireFile(typeMapPath) + val typeMapJson = ujson.read(typeMapJsonContent) + typeMapJson.obj.map { case (k, v) => k.toInt -> v.str }.toMap + else + Map.empty[Int, String] - val jsonContent = IOUtils.readEntireFile(file) - val json = ujson.read(jsonContent) - val filename = json("relativeName").str - val fullPath = Paths.get(rootPath.toString, filename) - val sourceFileContent = IOUtils.readEntireFile(fullPath) - ParseResult(filename, fullPath.toString, json, sourceFileContent, typeMap) + val jsonContent = IOUtils.readEntireFile(file) + val json = ujson.read(jsonContent) + val filename = json("relativeName").str + val fullPath = Paths.get(rootPath.toString, filename) + val sourceFileContent = IOUtils.readEntireFile(fullPath) + ParseResult(filename, fullPath.toString, json, sourceFileContent, typeMap) end BabelJsonParser diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/AstCreationPass.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/AstCreationPass.scala index 769127f7..0878d06c 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/AstCreationPass.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/AstCreationPass.scala @@ -25,43 +25,43 @@ class AstCreationPass( implicit withSchemaValidation: ValidationMode ) extends ConcurrentWriterCpgPass[(String, String)](cpg): - private val logger: Logger = LoggerFactory.getLogger(classOf[AstCreationPass]) + private val logger: Logger = LoggerFactory.getLogger(classOf[AstCreationPass]) - private val usedTypes: ConcurrentHashMap[(String, String), Boolean] = new ConcurrentHashMap() + private val usedTypes: ConcurrentHashMap[(String, String), Boolean] = new ConcurrentHashMap() - override def generateParts(): Array[(String, String)] = astGenRunnerResult.parsedFiles.toArray + override def generateParts(): Array[(String, String)] = astGenRunnerResult.parsedFiles.toArray - def allUsedTypes(): List[(String, String)] = - usedTypes.keys().asScala.filterNot { case (typeName, _) => typeName == Defines.Any }.toList + def allUsedTypes(): List[(String, String)] = + usedTypes.keys().asScala.filterNot { case (typeName, _) => typeName == Defines.Any }.toList - override def finish(): Unit = - astGenRunnerResult.skippedFiles.foreach { skippedFile => - val (rootPath, fileName) = skippedFile - val filePath = Paths.get(rootPath, fileName) - val fileLOC = IOUtils.readLinesInFile(filePath).size - report.addReportInfo(fileName, fileLOC) - } + override def finish(): Unit = + astGenRunnerResult.skippedFiles.foreach { skippedFile => + val (rootPath, fileName) = skippedFile + val filePath = Paths.get(rootPath, fileName) + val fileLOC = IOUtils.readLinesInFile(filePath).size + report.addReportInfo(fileName, fileLOC) + } - override def runOnPart(diffGraph: DiffGraphBuilder, input: (String, String)): Unit = - val (rootPath, jsonFilename) = input - val ((gotCpg, filename), duration) = TimeUtils.time { - val parseResult = BabelJsonParser.readFile(Paths.get(rootPath), Paths.get(jsonFilename)) - val fileLOC = IOUtils.readLinesInFile(Paths.get(parseResult.fullPath)).size - report.addReportInfo(parseResult.filename, fileLOC, parsed = true) - Try { - val localDiff = new AstCreator(config, parseResult, usedTypes).createAst() - diffGraph.absorb(localDiff) - } match - case Failure(exception) => - logger.warn( - s"Failed to generate a CPG for: '${parseResult.fullPath}'", - exception - ) - (false, parseResult.filename) - case Success(_) => - logger.debug(s"Generated a CPG for: '${parseResult.fullPath}'") - (true, parseResult.filename) - } - report.updateReport(filename, cpg = gotCpg, duration) - end runOnPart + override def runOnPart(diffGraph: DiffGraphBuilder, input: (String, String)): Unit = + val (rootPath, jsonFilename) = input + val ((gotCpg, filename), duration) = TimeUtils.time { + val parseResult = BabelJsonParser.readFile(Paths.get(rootPath), Paths.get(jsonFilename)) + val fileLOC = IOUtils.readLinesInFile(Paths.get(parseResult.fullPath)).size + report.addReportInfo(parseResult.filename, fileLOC, parsed = true) + Try { + val localDiff = new AstCreator(config, parseResult, usedTypes).createAst() + diffGraph.absorb(localDiff) + } match + case Failure(exception) => + logger.warn( + s"Failed to generate a CPG for: '${parseResult.fullPath}'", + exception + ) + (false, parseResult.filename) + case Success(_) => + logger.debug(s"Generated a CPG for: '${parseResult.fullPath}'") + (true, parseResult.filename) + } + report.updateReport(filename, cpg = gotCpg, duration) + end runOnPart end AstCreationPass diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/BuiltinTypesPass.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/BuiltinTypesPass.scala index 951ed10b..49650a5c 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/BuiltinTypesPass.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/BuiltinTypesPass.scala @@ -7,33 +7,33 @@ import io.shiftleft.passes.CpgPass class BuiltinTypesPass(cpg: Cpg) extends CpgPass(cpg): - override def run(diffGraph: DiffGraphBuilder): Unit = - val namespaceBlock = NewNamespaceBlock() - .name(Defines.GlobalNamespace) - .fullName(Defines.GlobalNamespace) - .order(0) - .filename("builtintypes") + override def run(diffGraph: DiffGraphBuilder): Unit = + val namespaceBlock = NewNamespaceBlock() + .name(Defines.GlobalNamespace) + .fullName(Defines.GlobalNamespace) + .order(0) + .filename("builtintypes") - diffGraph.addNode(namespaceBlock) + diffGraph.addNode(namespaceBlock) - Defines.JsTypes.zipWithIndex.map { case (typeName: String, index) => - val tpe = NewType() - .name(typeName) - .fullName(typeName) - .typeDeclFullName(typeName) - diffGraph.addNode(tpe) + Defines.JsTypes.zipWithIndex.map { case (typeName: String, index) => + val tpe = NewType() + .name(typeName) + .fullName(typeName) + .typeDeclFullName(typeName) + diffGraph.addNode(tpe) - val typeDecl = NewTypeDecl() - .name(typeName) - .fullName(typeName) - .isExternal(true) - .astParentType(NodeTypes.NAMESPACE_BLOCK) - .astParentFullName(Defines.GlobalNamespace) - .order(index + 1) - .filename("builtintypes") + val typeDecl = NewTypeDecl() + .name(typeName) + .fullName(typeName) + .isExternal(true) + .astParentType(NodeTypes.NAMESPACE_BLOCK) + .astParentFullName(Defines.GlobalNamespace) + .order(index + 1) + .filename("builtintypes") - diffGraph.addNode(typeDecl) - diffGraph.addEdge(namespaceBlock, typeDecl, EdgeTypes.AST) - } - end run + diffGraph.addNode(typeDecl) + diffGraph.addEdge(namespaceBlock, typeDecl, EdgeTypes.AST) + } + end run end BuiltinTypesPass diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/ConfigPass.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/ConfigPass.scala index 5a90a26f..803f5a72 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/ConfigPass.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/ConfigPass.scala @@ -13,37 +13,37 @@ import org.slf4j.{Logger, LoggerFactory} class ConfigPass(cpg: Cpg, config: Config, report: Report = new Report()) extends ConcurrentWriterCpgPass[File](cpg): - private val logger: Logger = LoggerFactory.getLogger(getClass) - - protected val allExtensions: Set[String] = Set(".json", ".js", ".vue", ".html", ".pug") - protected val selectedExtensions: Set[String] = - Set(".json", ".config.js", ".conf.js", ".vue", ".html", ".pug") - - override def generateParts(): Array[File] = - configFiles(config, allExtensions).toArray - - protected def fileContent(file: File): Seq[String] = - IOUtils.readLinesInFile(file.path) - - protected def configFiles(config: Config, extensions: Set[String]): Seq[File] = - SourceFiles - .determine(config.inputPath, extensions) - .filterNot(_.contains(Defines.NodeModulesFolder)) - .filter(f => selectedExtensions.exists(f.endsWith)) - .map(File(_)) - - override def runOnPart(diffGraph: DiffGraphBuilder, file: File): Unit = - val path = File(config.inputPath).path.toAbsolutePath.relativize(file.path).toString - logger.debug(s"Adding file '$path' as config.") - val (gotCpg, duration) = TimeUtils.time { - val localDiff = new DiffGraphBuilder - val content = fileContent(file) - val loc = content.size - val configNode = NewConfigFile().name(path).content(content.mkString("\n")) - report.addReportInfo(path, loc, parsed = true) - localDiff.addNode(configNode) - localDiff - } - diffGraph.absorb(gotCpg) - report.updateReport(path, cpg = true, duration) + private val logger: Logger = LoggerFactory.getLogger(getClass) + + protected val allExtensions: Set[String] = Set(".json", ".js", ".vue", ".html", ".pug") + protected val selectedExtensions: Set[String] = + Set(".json", ".config.js", ".conf.js", ".vue", ".html", ".pug") + + override def generateParts(): Array[File] = + configFiles(config, allExtensions).toArray + + protected def fileContent(file: File): Seq[String] = + IOUtils.readLinesInFile(file.path) + + protected def configFiles(config: Config, extensions: Set[String]): Seq[File] = + SourceFiles + .determine(config.inputPath, extensions) + .filterNot(_.contains(Defines.NodeModulesFolder)) + .filter(f => selectedExtensions.exists(f.endsWith)) + .map(File(_)) + + override def runOnPart(diffGraph: DiffGraphBuilder, file: File): Unit = + val path = File(config.inputPath).path.toAbsolutePath.relativize(file.path).toString + logger.debug(s"Adding file '$path' as config.") + val (gotCpg, duration) = TimeUtils.time { + val localDiff = new DiffGraphBuilder + val content = fileContent(file) + val loc = content.size + val configNode = NewConfigFile().name(path).content(content.mkString("\n")) + report.addReportInfo(path, loc, parsed = true) + localDiff.addNode(configNode) + localDiff + } + diffGraph.absorb(gotCpg) + report.updateReport(path, cpg = true, duration) end ConfigPass diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/ConstClosurePass.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/ConstClosurePass.scala index 66d66332..9ba86001 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/ConstClosurePass.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/ConstClosurePass.scala @@ -11,59 +11,59 @@ import io.shiftleft.semanticcpg.language.* */ class ConstClosurePass(cpg: Cpg) extends CpgPass(cpg): - // Keeps track of how many times an identifier has been on the LHS of an assignment, by name - private lazy val identifiersAssignedCount: Map[String, Int] = - cpg.assignment.target.collectAll[Identifier].name.groupCount + // Keeps track of how many times an identifier has been on the LHS of an assignment, by name + private lazy val identifiersAssignedCount: Map[String, Int] = + cpg.assignment.target.collectAll[Identifier].name.groupCount - override def run(diffGraph: DiffGraphBuilder): Unit = - handleConstClosures(diffGraph) - handleClosuresDefinedAtExport(diffGraph) - handleClosuresAssignedToMutableVar(diffGraph) + override def run(diffGraph: DiffGraphBuilder): Unit = + handleConstClosures(diffGraph) + handleClosuresDefinedAtExport(diffGraph) + handleClosuresAssignedToMutableVar(diffGraph) - private def handleConstClosures(diffGraph: DiffGraphBuilder): Unit = - for - assignment <- cpg.assignment - name <- assignment.filter(_.code.startsWith("const ")).target.isIdentifier.name - methodRef <- assignment.start.source.isMethodRef - method <- methodRef.referencedMethod - enclosingMethod <- assignment.start.method.fullName - do - updateClosures(diffGraph, method, methodRef, enclosingMethod, name) + private def handleConstClosures(diffGraph: DiffGraphBuilder): Unit = + for + assignment <- cpg.assignment + name <- assignment.filter(_.code.startsWith("const ")).target.isIdentifier.name + methodRef <- assignment.start.source.isMethodRef + method <- methodRef.referencedMethod + enclosingMethod <- assignment.start.method.fullName + do + updateClosures(diffGraph, method, methodRef, enclosingMethod, name) - private def handleClosuresDefinedAtExport(diffGraph: DiffGraphBuilder): Unit = - for - assignment <- cpg.assignment - name <- assignment.filter( - _.code.startsWith("export") - ).target.isCall.argument.isFieldIdentifier.canonicalName.l - methodRef <- assignment.start.source.ast.isMethodRef - method <- methodRef.referencedMethod - enclosingMethod <- assignment.start.method.fullName - do - updateClosures(diffGraph, method, methodRef, enclosingMethod, name) + private def handleClosuresDefinedAtExport(diffGraph: DiffGraphBuilder): Unit = + for + assignment <- cpg.assignment + name <- assignment.filter( + _.code.startsWith("export") + ).target.isCall.argument.isFieldIdentifier.canonicalName.l + methodRef <- assignment.start.source.ast.isMethodRef + method <- methodRef.referencedMethod + enclosingMethod <- assignment.start.method.fullName + do + updateClosures(diffGraph, method, methodRef, enclosingMethod, name) - private def handleClosuresAssignedToMutableVar(diffGraph: DiffGraphBuilder): Unit = - // Handle closures assigned to mutable variables - for - assignment <- cpg.assignment - name <- assignment.start.code("^(var|let) .*").target.isIdentifier.name - methodRef <- assignment.start.source.ast.isMethodRef - method <- methodRef.referencedMethod - enclosingMethod <- assignment.start.method.fullName - do - // Conservatively update closures, i.e, if we only find 1 assignment where this variable is on the LHS - if identifiersAssignedCount.getOrElse(name, -1) == 1 then - updateClosures(diffGraph, method, methodRef, enclosingMethod, name) + private def handleClosuresAssignedToMutableVar(diffGraph: DiffGraphBuilder): Unit = + // Handle closures assigned to mutable variables + for + assignment <- cpg.assignment + name <- assignment.start.code("^(var|let) .*").target.isIdentifier.name + methodRef <- assignment.start.source.ast.isMethodRef + method <- methodRef.referencedMethod + enclosingMethod <- assignment.start.method.fullName + do + // Conservatively update closures, i.e, if we only find 1 assignment where this variable is on the LHS + if identifiersAssignedCount.getOrElse(name, -1) == 1 then + updateClosures(diffGraph, method, methodRef, enclosingMethod, name) - private def updateClosures( - diffGraph: DiffGraphBuilder, - method: Method, - methodRef: MethodRef, - enclosingMethod: String, - name: String - ): Unit = - val fullName = s"$enclosingMethod:$name" - diffGraph.setNodeProperty(methodRef, PropertyNames.METHOD_FULL_NAME, fullName) - diffGraph.setNodeProperty(method, PropertyNames.NAME, name) - diffGraph.setNodeProperty(method, PropertyNames.FULL_NAME, fullName) + private def updateClosures( + diffGraph: DiffGraphBuilder, + method: Method, + methodRef: MethodRef, + enclosingMethod: String, + name: String + ): Unit = + val fullName = s"$enclosingMethod:$name" + diffGraph.setNodeProperty(methodRef, PropertyNames.METHOD_FULL_NAME, fullName) + diffGraph.setNodeProperty(method, PropertyNames.NAME, name) + diffGraph.setNodeProperty(method, PropertyNames.FULL_NAME, fullName) end ConstClosurePass diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/Defines.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/Defines.scala index bd24e791..e435f8cf 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/Defines.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/Defines.scala @@ -3,38 +3,38 @@ package io.appthreat.jssrc2cpg.passes import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal object Defines: - val Any: String = "ANY" - val Number: String = "__ecma.Number" - val String: String = "__ecma.String" - val Boolean: String = "__ecma.Boolean" - val Null: String = "__ecma.Null" - val Math: String = "__ecma.Math" - val Symbol: String = "__ecma.Symbol" - val Console: String = "__whatwg.console" - val Object: String = "object" - val BigInt: String = "bigint" - val Unknown: String = "unknown" - val Void: String = "void" - val Never: String = "never" - val Undefined: String = "undefined" - val NodeModulesFolder: String = "node_modules" - val GlobalNamespace: String = NamespaceTraversal.globalNamespaceName + val Any: String = "ANY" + val Number: String = "__ecma.Number" + val String: String = "__ecma.String" + val Boolean: String = "__ecma.Boolean" + val Null: String = "__ecma.Null" + val Math: String = "__ecma.Math" + val Symbol: String = "__ecma.Symbol" + val Console: String = "__whatwg.console" + val Object: String = "object" + val BigInt: String = "bigint" + val Unknown: String = "unknown" + val Void: String = "void" + val Never: String = "never" + val Undefined: String = "undefined" + val NodeModulesFolder: String = "node_modules" + val GlobalNamespace: String = NamespaceTraversal.globalNamespaceName - val JsTypes: List[String] = - List( - Any, - Number, - String, - Boolean, - Null, - Math, - Symbol, - Console, - Object, - BigInt, - Unknown, - Never, - Void, - Undefined - ) + val JsTypes: List[String] = + List( + Any, + Number, + String, + Boolean, + Null, + Math, + Symbol, + Console, + Object, + BigInt, + Unknown, + Never, + Void, + Undefined + ) end Defines diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/DependenciesPass.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/DependenciesPass.scala index cd0eabb8..77b20e7b 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/DependenciesPass.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/DependenciesPass.scala @@ -13,24 +13,24 @@ import java.nio.file.Paths */ class DependenciesPass(cpg: Cpg, config: Config) extends CpgPass(cpg): - override def run(diffGraph: DiffGraphBuilder): Unit = - val packagesJsons = SourceFiles - .determine(config.inputPath, Set(".json")) - .filterNot(_.contains(Defines.NodeModulesFolder)) - .filter(f => - f.endsWith(PackageJsonParser.PackageJsonFilename) || f.endsWith( - PackageJsonParser.PackageJsonLockFilename - ) + override def run(diffGraph: DiffGraphBuilder): Unit = + val packagesJsons = SourceFiles + .determine(config.inputPath, Set(".json")) + .filterNot(_.contains(Defines.NodeModulesFolder)) + .filter(f => + f.endsWith(PackageJsonParser.PackageJsonFilename) || f.endsWith( + PackageJsonParser.PackageJsonLockFilename ) + ) - val dependencies: Map[String, String] = - packagesJsons.flatMap(p => PackageJsonParser.dependencies(Paths.get(p))).toMap + val dependencies: Map[String, String] = + packagesJsons.flatMap(p => PackageJsonParser.dependencies(Paths.get(p))).toMap - dependencies.foreach { case (name, version) => - val dep = NewDependency() - .name(name) - .version(version) - diffGraph.addNode(dep) - } - end run + dependencies.foreach { case (name, version) => + val dep = NewDependency() + .name(name) + .version(version) + diffGraph.addNode(dep) + } + end run end DependenciesPass diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/EcmaBuiltins.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/EcmaBuiltins.scala index 7956fa95..473becbf 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/EcmaBuiltins.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/EcmaBuiltins.scala @@ -1,4 +1,4 @@ package io.appthreat.jssrc2cpg.passes object EcmaBuiltins: - val arrayFactory = "__ecma.Array.factory" + val arrayFactory = "__ecma.Array.factory" diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/GlobalBuiltins.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/GlobalBuiltins.scala index ace5c6e2..11a10af8 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/GlobalBuiltins.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/GlobalBuiltins.scala @@ -2,1092 +2,1092 @@ package io.appthreat.jssrc2cpg.passes object GlobalBuiltins: - val builtins: Set[String] = Set( - "AggregateError", - "Array", - "ArrayBuffer", - "AsyncFunction", - "AsyncGenerator", - "AsyncGeneratorFunction", - "AsyncIterator", - "Atomics", - "BigInt", - "BigInt64Array", - "BigUint64Array", - "Boolean", - "Buffer.from", - "DataView", - "Date", - "Error", - "EvalError", - "FinalizationRegistry", - "Float32Array", - "Float64Array", - "Function", - "Generator", - "GeneratorFunction", - "HTMLImageElement", - "Iterator", - "Infinity", - "Int16Array", - "Int32Array", - "Int8Array", - "InternalError", - "Intl", - "Intl.Collator", - "Intl.DateTimeFormat", - "Intl.DisplayNames", - "Intl.DurationFormat", - "Intl.ListFormat", - "Intl.Locale", - "Intl.NumberFormat", - "Intl.PluralRules", - "Intl.RelativeTimeFormat", - "Intl.Segmenter", - "JSON", - "JSON.parse", - "JSON.stringify", - "Map", - "Math", - "NaN", - "Number", - "Number.isFinite", - "Number.isInteger", - "Number.isNaN", - "Number.isSafeInteger", - "Number.parseFloat", - "Number.parseInt", - "Number.prototype.toExponential", - "Number.prototype.toFixed", - "Number.prototype.toLocaleString", - "Number.prototype.toPrecision", - "Number.prototype.toSource", - "Number.prototype.toString", - "Number.prototype.valueOf", - "Object", - "Object.assign", - "Object.create", - "Object.defineProperties", - "Object.defineProperty", - "Object.entries", - "Object.freeze", - "Object.fromEntries", - "Object.getOwnPropertyDescriptor", - "Object.getOwnPropertyDescriptors", - "Object.getOwnPropertyNames", - "Object.getOwnPropertySymbols", - "Object.getPrototypeOf", - "Object.is", - "Object.isExtensible", - "Object.isFrozen", - "Object.isSealed", - "Object.keys", - "Object.preventExtensions", - "Object.prototype.__defineGetter__", - "Object.prototype.__defineSetter__", - "Object.prototype.__lookupGetter__", - "Object.prototype.__lookupSetter__", - "Object.prototype.hasOwnProperty", - "Object.prototype.isPrototypeOf", - "Object.prototype.propertyIsEnumerable", - "Object.prototype.toLocaleString", - "Object.prototype.toSource", - "Object.prototype.toString", - "Object.prototype.valueOf", - "Object.seal", - "Object.setPrototypeOf", - "Object.values", - "Promise", - "Promise.all", - "Promise.allSettled", - "Promise.any", - "Promise.race", - "Promise.reject", - "Promise.resolve", - "Proxy", - "RangeError", - "ReferenceError", - "Reflect", - "RegExp", - "Set", - "SharedArrayBuffer", - "String", - "Symbol", - "SyntaxError", - "TypeError", - "TypedArray", - "URIError", - "Uint16Array", - "Uint32Array", - "Uint8Array", - "Uint8ClampedArray", - "WeakMap", - "WeakRef", - "WeakSet", - "decodeURI", - "decodeURIComponent", - "encodeURI", - "encodeURIComponent", - "escape", - "eval", - "eval", - "fetch", - "globalThis", - "isFinite", - "isNaN", - "localStorage.setItem", - "parseFloat", - "parseInt", - "undefined", - "unescape", - "uneval", - "AbortController", - "AbortSignal", - "AbsoluteOrientationSensor", - "AbstractRange", - "Accelerometer", - "AesCbcParams", - "AesCtrParams", - "AesGcmParams", - "AesKeyGenParams", - "AmbientLightSensor", - "AnalyserNode", - "ANGLE_instanced_arrays", - "Animation", - "AnimationEffect", - "AnimationEvent", - "AnimationPlaybackEvent", - "AnimationTimeline", - "Attr", - "AudioBuffer", - "AudioBufferSourceNode", - "AudioContext", - "AudioData", - "AudioDecoder", - "AudioDestinationNode", - "AudioEncoder", - "AudioListener", - "AudioNode", - "AudioParam", - "AudioParamDescriptor", - "AudioParamMap", - "AudioProcessingEvent", - "AudioScheduledSourceNode", - "AudioSinkInfo", - "AudioTrack", - "AudioTrackList", - "AudioWorklet", - "AudioWorkletGlobalScope", - "AudioWorkletNode", - "AudioWorkletProcessor", - "AuthenticatorAssertionResponse", - "AuthenticatorAttestationResponse", - "AuthenticatorResponse", - "BackgroundFetchEvent", - "BackgroundFetchManager", - "BackgroundFetchRecord", - "BackgroundFetchRegistration", - "BackgroundFetchUpdateUIEvent", - "BarcodeDetector", - "BarProp", - "BaseAudioContext", - "BatteryManager", - "BeforeInstallPromptEvent", - "BeforeUnloadEvent", - "BiquadFilterNode", - "Blob", - "BlobEvent", - "Bluetooth", - "BluetoothCharacteristicProperties", - "BluetoothDevice", - "BluetoothRemoteGATTCharacteristic", - "BluetoothRemoteGATTDescriptor", - "BluetoothRemoteGATTServer", - "BluetoothRemoteGATTService", - "BluetoothUUID", - "BroadcastChannel", - "ByteLengthQueuingStrategy", - "Cache", - "CacheStorage", - "CanMakePaymentEvent", - "CanvasCaptureMediaStreamTrack", - "CanvasGradient", - "CanvasPattern", - "CanvasRenderingContext2D", - "CaptureController", - "CaretPosition", - "CDATASection", - "ChannelMergerNode", - "ChannelSplitterNode", - "CharacterData", - "Client", - "Clients", - "Clipboard", - "ClipboardEvent", - "ClipboardItem", - "CloseEvent", - "Comment", - "CompositionEvent", - "CompressionStream", - "console", - "ConstantSourceNode", - "ContactAddress", - "ContactsManager", - "ContentIndex", - "ContentIndexEvent", - "ContentVisibilityAutoStateChangeEvent", - "ConvolverNode", - "CookieChangeEvent", - "CookieStore", - "CookieStoreManager", - "CountQueuingStrategy", - "Credential", - "CredentialsContainer", - "Crypto", - "CryptoKey", - "CryptoKeyPair", - "CSPViolationReportBody", - "CSS", - "CSSAnimation", - "CSSConditionRule", - "CSSContainerRule", - "CSSCounterStyleRule", - "CSSFontFaceRule", - "CSSFontFeatureValuesRule", - "CSSFontPaletteValuesRule", - "CSSGroupingRule", - "CSSImageValue", - "CSSImportRule", - "CSSKeyframeRule", - "CSSKeyframesRule", - "CSSKeywordValue", - "CSSLayerBlockRule", - "CSSLayerStatementRule", - "CSSMathInvert", - "CSSMathMax", - "CSSMathMin", - "CSSMathNegate", - "CSSMathProduct", - "CSSMathSum", - "CSSMathValue", - "CSSMatrixComponent", - "CSSMediaRule", - "CSSNamespaceRule", - "CSSNumericArray", - "CSSNumericValue", - "CSSPageRule", - "CSSPerspective", - "CSSPositionValue", - "CSSPrimitiveValue", - "CSSPropertyRule", - "CSSPseudoElement", - "CSSRotate", - "CSSRule", - "CSSRuleList", - "CSSScale", - "CSSSkew", - "CSSSkewX", - "CSSSkewY", - "CSSStyleDeclaration", - "CSSStyleRule", - "CSSStyleSheet", - "CSSStyleValue", - "CSSSupportsRule", - "CSSTransformComponent", - "CSSTransformValue", - "CSSTransition", - "CSSTranslate", - "CSSUnitValue", - "CSSUnparsedValue", - "CSSValue", - "CSSValueList", - "CSSVariableReferenceValue", - "CustomElementRegistry", - "CustomEvent", - "CustomStateSet", - "DataTransfer", - "DataTransferItem", - "DataTransferItemList", - "DecompressionStream", - "DedicatedWorkerGlobalScope", - "DelayNode", - "DeprecationReportBody", - "DeviceMotionEvent", - "DeviceMotionEventAcceleration", - "DeviceMotionEventRotationRate", - "DeviceOrientationEvent", - "DirectoryEntrySync", - "DirectoryReaderSync", - "Document", - "DocumentFragment", - "DocumentTimeline", - "DocumentType", - "DOMError", - "DOMException", - "DOMHighResTimeStamp", - "DOMImplementation", - "DOMMatrix (WebKitCSSMatrix)", - "DOMMatrixReadOnly", - "DOMParser", - "DOMPoint", - "DOMPointReadOnly", - "DOMQuad", - "DOMRect", - "DOMRectReadOnly", - "DOMStringList", - "DOMStringMap", - "DOMTokenList", - "DragEvent", - "DynamicsCompressorNode", - "EcdhKeyDeriveParams", - "EcdsaParams", - "EcKeyGenParams", - "EcKeyImportParams", - "Element", - "ElementInternals", - "EncodedAudioChunk", - "EncodedVideoChunk", - "ErrorEvent", - "Event", - "EventCounts", - "EventSource", - "EventTarget", - "ExtendableCookieChangeEvent", - "ExtendableEvent", - "ExtendableMessageEvent", - "EyeDropper", - "FeaturePolicy", - "FederatedCredential", - "FetchEvent", - "File", - "FileEntrySync", - "FileList", - "FileReader", - "FileReaderSync", - "FileSystem", - "FileSystemDirectoryEntry", - "FileSystemDirectoryHandle", - "FileSystemDirectoryReader", - "FileSystemEntry", - "FileSystemFileEntry", - "FileSystemFileHandle", - "FileSystemHandle", - "FileSystemSync", - "FileSystemSyncAccessHandle", - "FileSystemWritableFileStream", - "FocusEvent", - "FontData", - "FontFace", - "FontFaceSet", - "FontFaceSetLoadEvent", - "FormData", - "FormDataEvent", - "FragmentDirective", - "GainNode", - "Gamepad", - "GamepadButton", - "GamepadEvent", - "GamepadHapticActuator", - "GamepadPose", - "Geolocation", - "GeolocationCoordinates", - "GeolocationPosition", - "GeolocationPositionError", - "GestureEvent", - "GPU", - "GPUAdapter", - "GPUAdapterInfo", - "GPUBindGroup", - "GPUBindGroupLayout", - "GPUBuffer", - "GPUCanvasContext", - "GPUCommandBuffer", - "GPUCommandEncoder", - "GPUCompilationInfo", - "GPUCompilationMessage", - "GPUComputePassEncoder", - "GPUComputePipeline", - "GPUDevice", - "GPUDeviceLostInfo", - "GPUError", - "GPUExternalTexture", - "GPUInternalError", - "GPUOutOfMemoryError", - "GPUPipelineError", - "GPUPipelineLayout", - "GPUQuerySet", - "GPUQueue", - "GPURenderBundle", - "GPURenderBundleEncoder", - "GPURenderPassEncoder", - "GPURenderPipeline", - "GPUSampler", - "GPUShaderModule", - "GPUSupportedFeatures", - "GPUSupportedLimits", - "GPUTexture", - "GPUTextureView", - "GPUUncapturedErrorEvent", - "GPUValidationError", - "GravitySensor", - "Gyroscope", - "HashChangeEvent", - "Headers", - "HID", - "HIDConnectionEvent", - "HIDDevice", - "HIDInputReportEvent", - "Highlight", - "HighlightRegistry", - "History", - "HkdfParams", - "HmacImportParams", - "HmacKeyGenParams", - "HMDVRDevice", - "HTMLAnchorElement", - "HTMLAreaElement", - "HTMLAudioElement", - "HTMLBaseElement", - "HTMLBodyElement", - "HTMLBRElement", - "HTMLButtonElement", - "HTMLCanvasElement", - "HTMLCollection", - "HTMLDataElement", - "HTMLDataListElement", - "HTMLDetailsElement", - "HTMLDialogElement", - "HTMLDivElement", - "HTMLDListElement", - "HTMLDocument", - "HTMLElement", - "HTMLEmbedElement", - "HTMLFieldSetElement", - "HTMLFontElement", - "HTMLFormControlsCollection", - "HTMLFormElement", - "HTMLFrameSetElement", - "HTMLHeadElement", - "HTMLHeadingElement", - "HTMLHRElement", - "HTMLHtmlElement", - "HTMLIFrameElement", - "HTMLImageElement", - "HTMLInputElement", - "HTMLLabelElement", - "HTMLLegendElement", - "HTMLLIElement", - "HTMLLinkElement", - "HTMLMapElement", - "HTMLMarqueeElement", - "HTMLMediaElement", - "HTMLMenuElement", - "HTMLMenuItemElement", - "HTMLMetaElement", - "HTMLMeterElement", - "HTMLModElement", - "HTMLObjectElement", - "HTMLOListElement", - "HTMLOptGroupElement", - "HTMLOptionElement", - "HTMLOptionsCollection", - "HTMLOutputElement", - "HTMLParagraphElement", - "HTMLParamElement", - "HTMLPictureElement", - "HTMLPreElement", - "HTMLProgressElement", - "HTMLQuoteElement", - "HTMLScriptElement", - "HTMLSelectElement", - "HTMLSlotElement", - "HTMLSourceElement", - "HTMLSpanElement", - "HTMLStyleElement", - "HTMLTableCaptionElement", - "HTMLTableCellElement", - "HTMLTableColElement", - "HTMLTableElement", - "HTMLTableRowElement", - "HTMLTableSectionElement", - "HTMLTemplateElement", - "HTMLTextAreaElement", - "HTMLTimeElement", - "HTMLTitleElement", - "HTMLTrackElement", - "HTMLUListElement", - "HTMLUnknownElement", - "HTMLVideoElement", - "IDBCursor", - "IDBCursorWithValue", - "IDBDatabase", - "IDBFactory", - "IDBIndex", - "IDBKeyRange", - "IDBLocaleAwareKeyRange", - "IDBObjectStore", - "IDBOpenDBRequest", - "IDBRequest", - "IDBTransaction", - "IDBVersionChangeEvent", - "IdentityCredential", - "IdleDeadline", - "IdleDetector", - "IIRFilterNode", - "ImageBitmap", - "ImageBitmapRenderingContext", - "ImageCapture", - "ImageData", - "ImageDecoder", - "ImageTrack", - "ImageTrackList", - "Ink", - "InkPresenter", - "InputDeviceCapabilities", - "InputDeviceInfo", - "InputEvent", - "InstallEvent", - "IntersectionObserver", - "IntersectionObserverEntry", - "InterventionReportBody", - "Keyboard", - "KeyboardEvent", - "KeyboardLayoutMap", - "KeyframeEffect", - "LargestContentfulPaint", - "LaunchParams", - "LaunchQueue", - "LayoutShift", - "LayoutShiftAttribution", - "LinearAccelerationSensor", - "Location", - "Lock", - "LockManager", - "Magnetometer", - "MathMLElement", - "MediaCapabilities", - "MediaDeviceInfo", - "MediaDevices", - "MediaElementAudioSourceNode", - "MediaEncryptedEvent", - "MediaError", - "MediaImage", - "MediaKeyMessageEvent", - "MediaKeys", - "MediaKeySession", - "MediaKeyStatusMap", - "MediaKeySystemAccess", - "MediaList", - "MediaMetadata", - "MediaQueryList", - "MediaQueryListEvent", - "MediaRecorder", - "MediaRecorderErrorEvent", - "MediaSession", - "MediaSource", - "MediaSourceHandle", - "MediaStream", - "MediaStreamAudioDestinationNode", - "MediaStreamAudioSourceNode", - "MediaStreamEvent", - "MediaStreamTrack", - "MediaStreamTrackAudioSourceNode", - "MediaStreamTrackEvent", - "MediaStreamTrackGenerator", - "MediaStreamTrackProcessor", - "MediaTrackConstraints", - "MediaTrackSettings", - "MediaTrackSupportedConstraints", - "MerchantValidationEvent", - "MessageChannel", - "MessageEvent", - "MessagePort", - "Metadata", - "MIDIAccess", - "MIDIConnectionEvent", - "MIDIInput", - "MIDIInputMap", - "MIDIMessageEvent", - "MIDIOutput", - "MIDIOutputMap", - "MIDIPort", - "MimeType", - "MimeTypeArray", - "MouseEvent", - "MouseScrollEvent", - "MutationEvent", - "MutationObserver", - "MutationRecord", - "NamedNodeMap", - "NavigateEvent", - "Navigation", - "NavigationCurrentEntryChangeEvent", - "NavigationDestination", - "NavigationHistoryEntry", - "NavigationPreloadManager", - "NavigationTransition", - "Navigator", - "NavigatorUAData", - "NDEFMessage", - "NDEFReader", - "NDEFReadingEvent", - "NDEFRecord", - "NetworkInformation", - "Node", - "NodeIterator", - "NodeList", - "Notification", - "NotificationEvent", - "NotifyAudioAvailableEvent", - "OES_draw_buffers_indexed", - "OfflineAudioCompletionEvent", - "OfflineAudioContext", - "OffscreenCanvas", - "OffscreenCanvasRenderingContext2D", - "OrientationSensor", - "OscillatorNode", - "OTPCredential", - "OverconstrainedError", - "PageTransitionEvent", - "PaintWorkletGlobalScope", - "PannerNode", - "PasswordCredential", - "Path2D", - "PaymentAddress", - "PaymentManager", - "PaymentMethodChangeEvent", - "PaymentRequest", - "PaymentRequestEvent", - "PaymentRequestUpdateEvent", - "PaymentResponse", - "Pbkdf2Params", - "Performance", - "PerformanceElementTiming", - "PerformanceEntry", - "PerformanceEventTiming", - "PerformanceLongTaskTiming", - "PerformanceMark", - "PerformanceMeasure", - "PerformanceNavigation", - "PerformanceNavigationTiming", - "PerformanceObserver", - "PerformanceObserverEntryList", - "PerformancePaintTiming", - "PerformanceResourceTiming", - "PerformanceServerTiming", - "PerformanceTiming", - "PeriodicSyncEvent", - "PeriodicSyncManager", - "PeriodicWave", - "Permissions", - "PermissionStatus", - "PictureInPictureEvent", - "PictureInPictureWindow", - "Plugin", - "PluginArray", - "Point", - "PointerEvent", - "PopStateEvent", - "PositionSensorVRDevice", - "Presentation", - "PresentationAvailability", - "PresentationConnection", - "PresentationConnectionAvailableEvent", - "PresentationConnectionCloseEvent", - "PresentationConnectionList", - "PresentationReceiver", - "PresentationRequest", - "ProcessingInstruction", - "ProgressEvent", - "PromiseRejectionEvent", - "PublicKeyCredential", - "PushEvent", - "PushManager", - "PushMessageData", - "PushSubscription", - "PushSubscriptionOptions", - "RadioNodeList", - "Range", - "ReadableByteStreamController", - "ReadableStream", - "ReadableStreamBYOBReader", - "ReadableStreamBYOBRequest", - "ReadableStreamDefaultController", - "ReadableStreamDefaultReader", - "RelativeOrientationSensor", - "RemotePlayback", - "Report", - "ReportBody", - "ReportingObserver", - "Request", - "ResizeObserver", - "ResizeObserverEntry", - "ResizeObserverSize", - "Response", - "RsaHashedImportParams", - "RsaHashedKeyGenParams", - "RsaOaepParams", - "RsaPssParams", - "RTCAudioSourceStats", - "RTCCertificate", - "RTCDataChannel", - "RTCDataChannelEvent", - "RTCDtlsTransport", - "RTCDTMFSender", - "RTCDTMFToneChangeEvent", - "RTCError", - "RTCErrorEvent", - "RTCIceCandidate", - "RTCIceCandidatePair", - "RTCIceCandidatePairStats", - "RTCIceCandidateStats", - "RTCIceParameters", - "RTCIceServer", - "RTCIceTransport", - "RTCIdentityAssertion", - "RTCInboundRtpStreamStats", - "RTCOutboundRtpStreamStats", - "RTCPeerConnection", - "RTCPeerConnectionIceErrorEvent", - "RTCPeerConnectionIceEvent", - "RTCPeerConnectionStats", - "RTCRemoteOutboundRtpStreamStats", - "RTCRtpCodecParameters", - "RTCRtpContributingSource", - "RTCRtpEncodingParameters", - "RTCRtpReceiver", - "RTCRtpSender", - "RTCRtpStreamStats", - "RTCRtpTransceiver", - "RTCSctpTransport", - "RTCSessionDescription", - "RTCStatsReport", - "RTCTrackEvent", - "Sanitizer", - "Scheduler", - "Screen", - "ScreenOrientation", - "ScriptProcessorNode", - "SecurityPolicyViolationEvent", - "Selection", - "Sensor", - "SensorErrorEvent", - "Serial", - "SerialPort", - "ServiceWorker", - "ServiceWorkerContainer", - "ServiceWorkerGlobalScope", - "ServiceWorkerRegistration", - "ShadowRoot", - "SharedWorker", - "SharedWorkerGlobalScope", - "SourceBuffer", - "SourceBufferList", - "SpeechGrammar", - "SpeechGrammarList", - "SpeechRecognition", - "SpeechRecognitionAlternative", - "SpeechRecognitionErrorEvent", - "SpeechRecognitionEvent", - "SpeechRecognitionResult", - "SpeechRecognitionResultList", - "SpeechSynthesis", - "SpeechSynthesisErrorEvent", - "SpeechSynthesisEvent", - "SpeechSynthesisUtterance", - "SpeechSynthesisVoice", - "StaticRange", - "StereoPannerNode", - "Storage", - "StorageEvent", - "StorageManager", - "StylePropertyMap", - "StylePropertyMapReadOnly", - "StyleSheet", - "StyleSheetList", - "SubmitEvent", - "SubtleCrypto", - "SVGAElement", - "SVGAngle", - "SVGAnimateColorElement", - "SVGAnimatedAngle", - "SVGAnimatedBoolean", - "SVGAnimatedEnumeration", - "SVGAnimatedInteger", - "SVGAnimatedLength", - "SVGAnimatedLengthList", - "SVGAnimatedNumber", - "SVGAnimatedNumberList", - "SVGAnimatedPreserveAspectRatio", - "SVGAnimatedRect", - "SVGAnimatedString", - "SVGAnimatedTransformList", - "SVGAnimateElement", - "SVGAnimateMotionElement", - "SVGAnimateTransformElement", - "SVGAnimationElement", - "SVGCircleElement", - "SVGClipPathElement", - "SVGComponentTransferFunctionElement", - "SVGCursorElement", - "SVGDefsElement", - "SVGDescElement", - "SVGElement", - "SVGEllipseElement", - "SVGEvent", - "SVGFEBlendElement", - "SVGFEColorMatrixElement", - "SVGFEComponentTransferElement", - "SVGFECompositeElement", - "SVGFEConvolveMatrixElement", - "SVGFEDiffuseLightingElement", - "SVGFEDisplacementMapElement", - "SVGFEDistantLightElement", - "SVGFEDropShadowElement", - "SVGFEFloodElement", - "SVGFEFuncAElement", - "SVGFEFuncBElement", - "SVGFEFuncGElement", - "SVGFEFuncRElement", - "SVGFEGaussianBlurElement", - "SVGFEImageElement", - "SVGFEMergeElement", - "SVGFEMergeNodeElement", - "SVGFEMorphologyElement", - "SVGFEOffsetElement", - "SVGFEPointLightElement", - "SVGFESpecularLightingElement", - "SVGFESpotLightElement", - "SVGFETileElement", - "SVGFETurbulenceElement", - "SVGFilterElement", - "SVGFontElement", - "SVGFontFaceElement", - "SVGFontFaceFormatElement", - "SVGFontFaceNameElement", - "SVGFontFaceSrcElement", - "SVGFontFaceUriElement", - "SVGForeignObjectElement", - "SVGGElement", - "SVGGeometryElement", - "SVGGlyphElement", - "SVGGlyphRefElement", - "SVGGradientElement", - "SVGGraphicsElement", - "SVGHKernElement", - "SVGImageElement", - "SVGLength", - "SVGLengthList", - "SVGLinearGradientElement", - "SVGLineElement", - "SVGMarkerElement", - "SVGMaskElement", - "SVGMetadataElement", - "SVGMissingGlyphElement", - "SVGMPathElement", - "SVGNumber", - "SVGNumberList", - "SVGPathElement", - "SVGPatternElement", - "SVGPoint", - "SVGPointList", - "SVGPolygonElement", - "SVGPolylineElement", - "SVGPreserveAspectRatio", - "SVGRadialGradientElement", - "SVGRect", - "SVGRectElement", - "SVGRenderingIntent", - "SVGScriptElement", - "SVGSetElement", - "SVGStopElement", - "SVGStringList", - "SVGStyleElement", - "SVGSVGElement", - "SVGSwitchElement", - "SVGSymbolElement", - "SVGTextContentElement", - "SVGTextElement", - "SVGTextPathElement", - "SVGTextPositioningElement", - "SVGTitleElement", - "SVGTransform", - "SVGTransformList", - "SVGTRefElement", - "SVGTSpanElement", - "SVGUnitTypes", - "SVGUseElement", - "SVGViewElement", - "SVGVKernElement", - "SyncEvent", - "SyncManager", - "TaskAttributionTiming", - "TaskController", - "TaskPriorityChangeEvent", - "TaskSignal", - "Text", - "TextDecoder", - "TextDecoderStream", - "TextEncoder", - "TextEncoderStream", - "TextMetrics", - "TextTrack", - "TextTrackCue", - "TextTrackCueList", - "TextTrackList", - "TimeEvent", - "TimeRanges", - "ToggleEvent", - "Touch", - "TouchEvent", - "TouchList", - "TrackEvent", - "TransformStream", - "TransformStreamDefaultController", - "TransitionEvent", - "TreeWalker", - "TrustedHTML", - "TrustedScript", - "TrustedScriptURL", - "TrustedTypePolicy", - "TrustedTypePolicyFactory", - "UIEvent", - "URL", - "URLPattern", - "URLSearchParams", - "USB", - "USBAlternateInterface", - "USBConfiguration", - "USBConnectionEvent", - "USBDevice", - "USBEndpoint", - "USBInterface", - "USBInTransferResult", - "USBIsochronousInTransferPacket", - "USBIsochronousInTransferResult", - "USBIsochronousOutTransferPacket", - "USBIsochronousOutTransferResult", - "USBOutTransferResult", - "UserActivation", - "ValidityState", - "VideoColorSpace", - "VideoDecoder", - "VideoEncoder", - "VideoFrame", - "VideoPlaybackQuality", - "VideoTrack", - "VideoTrackList", - "ViewTransition", - "VirtualKeyboard", - "VisualViewport", - "VRDisplay", - "VRDisplayCapabilities", - "VRDisplayEvent", - "VREyeParameters", - "VRFieldOfView", - "VRFrameData", - "VRLayerInit", - "VRPose", - "VRStageParameters", - "VTTCue", - "VTTRegion", - "WakeLock", - "WakeLockSentinel", - "WaveShaperNode", - "WebGL2RenderingContext", - "WebGLActiveInfo", - "WebGLBuffer", - "WebGLContextEvent", - "WebGLFramebuffer", - "WebGLObject", - "WebGLProgram", - "WebGLQuery", - "WebGLRenderbuffer", - "WebGLRenderingContext", - "WebGLSampler", - "WebGLShader", - "WebGLShaderPrecisionFormat", - "WebGLSync", - "WebGLTexture", - "WebGLTransformFeedback", - "WebGLUniformLocation", - "WebGLVertexArrayObject", - "WebSocket", - "WebTransport", - "WebTransportBidirectionalStream", - "WebTransportDatagramDuplexStream", - "WebTransportError", - "WebTransportReceiveStream", - "WheelEvent", - "Window", - "WindowClient", - "WindowControlsOverlay", - "WindowControlsOverlayGeometryChangeEvent", - "Worker", - "WorkerGlobalScope", - "WorkerLocation", - "WorkerNavigator", - "Worklet", - "WorkletGlobalScope", - "WritableStream", - "WritableStreamDefaultController", - "WritableStreamDefaultWriter", - "XMLDocument", - "XMLHttpRequest", - "XMLHttpRequestEventTarget", - "XMLHttpRequestUpload", - "XMLSerializer", - "XPathEvaluator", - "XPathException", - "XPathExpression", - "XPathNSResolver", - "XPathResult", - "XRAnchor", - "XRAnchorSet", - "XRBoundedReferenceSpace", - "XRCompositionLayer", - "XRCPUDepthInformation", - "XRCubeLayer", - "XRCylinderLayer", - "XRDepthInformation", - "XREquirectLayer", - "XRFrame", - "XRHand", - "XRHitTestResult", - "XRHitTestSource", - "XRInputSource", - "XRInputSourceArray", - "XRInputSourceEvent", - "XRInputSourcesChangeEvent", - "XRJointPose", - "XRJointSpace", - "XRLayer", - "XRLayerEvent", - "XRLightEstimate", - "XRLightProbe", - "XRMediaBinding", - "XRPose", - "XRProjectionLayer", - "XRQuadLayer", - "XRRay", - "XRReferenceSpace", - "XRReferenceSpaceEvent", - "XRRenderState", - "XRRigidTransform", - "XRSession", - "XRSessionEvent", - "XRSpace", - "XRSubImage", - "XRSystem", - "XRTransientInputHitTestResult", - "XRTransientInputHitTestSource", - "XRView", - "XRViewerPose", - "XRViewport", - "XRWebGLBinding", - "XRWebGLDepthInformation", - "XRWebGLLayer", - "XRWebGLSubImage", - "XSLTProcessor" - ) + val builtins: Set[String] = Set( + "AggregateError", + "Array", + "ArrayBuffer", + "AsyncFunction", + "AsyncGenerator", + "AsyncGeneratorFunction", + "AsyncIterator", + "Atomics", + "BigInt", + "BigInt64Array", + "BigUint64Array", + "Boolean", + "Buffer.from", + "DataView", + "Date", + "Error", + "EvalError", + "FinalizationRegistry", + "Float32Array", + "Float64Array", + "Function", + "Generator", + "GeneratorFunction", + "HTMLImageElement", + "Iterator", + "Infinity", + "Int16Array", + "Int32Array", + "Int8Array", + "InternalError", + "Intl", + "Intl.Collator", + "Intl.DateTimeFormat", + "Intl.DisplayNames", + "Intl.DurationFormat", + "Intl.ListFormat", + "Intl.Locale", + "Intl.NumberFormat", + "Intl.PluralRules", + "Intl.RelativeTimeFormat", + "Intl.Segmenter", + "JSON", + "JSON.parse", + "JSON.stringify", + "Map", + "Math", + "NaN", + "Number", + "Number.isFinite", + "Number.isInteger", + "Number.isNaN", + "Number.isSafeInteger", + "Number.parseFloat", + "Number.parseInt", + "Number.prototype.toExponential", + "Number.prototype.toFixed", + "Number.prototype.toLocaleString", + "Number.prototype.toPrecision", + "Number.prototype.toSource", + "Number.prototype.toString", + "Number.prototype.valueOf", + "Object", + "Object.assign", + "Object.create", + "Object.defineProperties", + "Object.defineProperty", + "Object.entries", + "Object.freeze", + "Object.fromEntries", + "Object.getOwnPropertyDescriptor", + "Object.getOwnPropertyDescriptors", + "Object.getOwnPropertyNames", + "Object.getOwnPropertySymbols", + "Object.getPrototypeOf", + "Object.is", + "Object.isExtensible", + "Object.isFrozen", + "Object.isSealed", + "Object.keys", + "Object.preventExtensions", + "Object.prototype.__defineGetter__", + "Object.prototype.__defineSetter__", + "Object.prototype.__lookupGetter__", + "Object.prototype.__lookupSetter__", + "Object.prototype.hasOwnProperty", + "Object.prototype.isPrototypeOf", + "Object.prototype.propertyIsEnumerable", + "Object.prototype.toLocaleString", + "Object.prototype.toSource", + "Object.prototype.toString", + "Object.prototype.valueOf", + "Object.seal", + "Object.setPrototypeOf", + "Object.values", + "Promise", + "Promise.all", + "Promise.allSettled", + "Promise.any", + "Promise.race", + "Promise.reject", + "Promise.resolve", + "Proxy", + "RangeError", + "ReferenceError", + "Reflect", + "RegExp", + "Set", + "SharedArrayBuffer", + "String", + "Symbol", + "SyntaxError", + "TypeError", + "TypedArray", + "URIError", + "Uint16Array", + "Uint32Array", + "Uint8Array", + "Uint8ClampedArray", + "WeakMap", + "WeakRef", + "WeakSet", + "decodeURI", + "decodeURIComponent", + "encodeURI", + "encodeURIComponent", + "escape", + "eval", + "eval", + "fetch", + "globalThis", + "isFinite", + "isNaN", + "localStorage.setItem", + "parseFloat", + "parseInt", + "undefined", + "unescape", + "uneval", + "AbortController", + "AbortSignal", + "AbsoluteOrientationSensor", + "AbstractRange", + "Accelerometer", + "AesCbcParams", + "AesCtrParams", + "AesGcmParams", + "AesKeyGenParams", + "AmbientLightSensor", + "AnalyserNode", + "ANGLE_instanced_arrays", + "Animation", + "AnimationEffect", + "AnimationEvent", + "AnimationPlaybackEvent", + "AnimationTimeline", + "Attr", + "AudioBuffer", + "AudioBufferSourceNode", + "AudioContext", + "AudioData", + "AudioDecoder", + "AudioDestinationNode", + "AudioEncoder", + "AudioListener", + "AudioNode", + "AudioParam", + "AudioParamDescriptor", + "AudioParamMap", + "AudioProcessingEvent", + "AudioScheduledSourceNode", + "AudioSinkInfo", + "AudioTrack", + "AudioTrackList", + "AudioWorklet", + "AudioWorkletGlobalScope", + "AudioWorkletNode", + "AudioWorkletProcessor", + "AuthenticatorAssertionResponse", + "AuthenticatorAttestationResponse", + "AuthenticatorResponse", + "BackgroundFetchEvent", + "BackgroundFetchManager", + "BackgroundFetchRecord", + "BackgroundFetchRegistration", + "BackgroundFetchUpdateUIEvent", + "BarcodeDetector", + "BarProp", + "BaseAudioContext", + "BatteryManager", + "BeforeInstallPromptEvent", + "BeforeUnloadEvent", + "BiquadFilterNode", + "Blob", + "BlobEvent", + "Bluetooth", + "BluetoothCharacteristicProperties", + "BluetoothDevice", + "BluetoothRemoteGATTCharacteristic", + "BluetoothRemoteGATTDescriptor", + "BluetoothRemoteGATTServer", + "BluetoothRemoteGATTService", + "BluetoothUUID", + "BroadcastChannel", + "ByteLengthQueuingStrategy", + "Cache", + "CacheStorage", + "CanMakePaymentEvent", + "CanvasCaptureMediaStreamTrack", + "CanvasGradient", + "CanvasPattern", + "CanvasRenderingContext2D", + "CaptureController", + "CaretPosition", + "CDATASection", + "ChannelMergerNode", + "ChannelSplitterNode", + "CharacterData", + "Client", + "Clients", + "Clipboard", + "ClipboardEvent", + "ClipboardItem", + "CloseEvent", + "Comment", + "CompositionEvent", + "CompressionStream", + "console", + "ConstantSourceNode", + "ContactAddress", + "ContactsManager", + "ContentIndex", + "ContentIndexEvent", + "ContentVisibilityAutoStateChangeEvent", + "ConvolverNode", + "CookieChangeEvent", + "CookieStore", + "CookieStoreManager", + "CountQueuingStrategy", + "Credential", + "CredentialsContainer", + "Crypto", + "CryptoKey", + "CryptoKeyPair", + "CSPViolationReportBody", + "CSS", + "CSSAnimation", + "CSSConditionRule", + "CSSContainerRule", + "CSSCounterStyleRule", + "CSSFontFaceRule", + "CSSFontFeatureValuesRule", + "CSSFontPaletteValuesRule", + "CSSGroupingRule", + "CSSImageValue", + "CSSImportRule", + "CSSKeyframeRule", + "CSSKeyframesRule", + "CSSKeywordValue", + "CSSLayerBlockRule", + "CSSLayerStatementRule", + "CSSMathInvert", + "CSSMathMax", + "CSSMathMin", + "CSSMathNegate", + "CSSMathProduct", + "CSSMathSum", + "CSSMathValue", + "CSSMatrixComponent", + "CSSMediaRule", + "CSSNamespaceRule", + "CSSNumericArray", + "CSSNumericValue", + "CSSPageRule", + "CSSPerspective", + "CSSPositionValue", + "CSSPrimitiveValue", + "CSSPropertyRule", + "CSSPseudoElement", + "CSSRotate", + "CSSRule", + "CSSRuleList", + "CSSScale", + "CSSSkew", + "CSSSkewX", + "CSSSkewY", + "CSSStyleDeclaration", + "CSSStyleRule", + "CSSStyleSheet", + "CSSStyleValue", + "CSSSupportsRule", + "CSSTransformComponent", + "CSSTransformValue", + "CSSTransition", + "CSSTranslate", + "CSSUnitValue", + "CSSUnparsedValue", + "CSSValue", + "CSSValueList", + "CSSVariableReferenceValue", + "CustomElementRegistry", + "CustomEvent", + "CustomStateSet", + "DataTransfer", + "DataTransferItem", + "DataTransferItemList", + "DecompressionStream", + "DedicatedWorkerGlobalScope", + "DelayNode", + "DeprecationReportBody", + "DeviceMotionEvent", + "DeviceMotionEventAcceleration", + "DeviceMotionEventRotationRate", + "DeviceOrientationEvent", + "DirectoryEntrySync", + "DirectoryReaderSync", + "Document", + "DocumentFragment", + "DocumentTimeline", + "DocumentType", + "DOMError", + "DOMException", + "DOMHighResTimeStamp", + "DOMImplementation", + "DOMMatrix (WebKitCSSMatrix)", + "DOMMatrixReadOnly", + "DOMParser", + "DOMPoint", + "DOMPointReadOnly", + "DOMQuad", + "DOMRect", + "DOMRectReadOnly", + "DOMStringList", + "DOMStringMap", + "DOMTokenList", + "DragEvent", + "DynamicsCompressorNode", + "EcdhKeyDeriveParams", + "EcdsaParams", + "EcKeyGenParams", + "EcKeyImportParams", + "Element", + "ElementInternals", + "EncodedAudioChunk", + "EncodedVideoChunk", + "ErrorEvent", + "Event", + "EventCounts", + "EventSource", + "EventTarget", + "ExtendableCookieChangeEvent", + "ExtendableEvent", + "ExtendableMessageEvent", + "EyeDropper", + "FeaturePolicy", + "FederatedCredential", + "FetchEvent", + "File", + "FileEntrySync", + "FileList", + "FileReader", + "FileReaderSync", + "FileSystem", + "FileSystemDirectoryEntry", + "FileSystemDirectoryHandle", + "FileSystemDirectoryReader", + "FileSystemEntry", + "FileSystemFileEntry", + "FileSystemFileHandle", + "FileSystemHandle", + "FileSystemSync", + "FileSystemSyncAccessHandle", + "FileSystemWritableFileStream", + "FocusEvent", + "FontData", + "FontFace", + "FontFaceSet", + "FontFaceSetLoadEvent", + "FormData", + "FormDataEvent", + "FragmentDirective", + "GainNode", + "Gamepad", + "GamepadButton", + "GamepadEvent", + "GamepadHapticActuator", + "GamepadPose", + "Geolocation", + "GeolocationCoordinates", + "GeolocationPosition", + "GeolocationPositionError", + "GestureEvent", + "GPU", + "GPUAdapter", + "GPUAdapterInfo", + "GPUBindGroup", + "GPUBindGroupLayout", + "GPUBuffer", + "GPUCanvasContext", + "GPUCommandBuffer", + "GPUCommandEncoder", + "GPUCompilationInfo", + "GPUCompilationMessage", + "GPUComputePassEncoder", + "GPUComputePipeline", + "GPUDevice", + "GPUDeviceLostInfo", + "GPUError", + "GPUExternalTexture", + "GPUInternalError", + "GPUOutOfMemoryError", + "GPUPipelineError", + "GPUPipelineLayout", + "GPUQuerySet", + "GPUQueue", + "GPURenderBundle", + "GPURenderBundleEncoder", + "GPURenderPassEncoder", + "GPURenderPipeline", + "GPUSampler", + "GPUShaderModule", + "GPUSupportedFeatures", + "GPUSupportedLimits", + "GPUTexture", + "GPUTextureView", + "GPUUncapturedErrorEvent", + "GPUValidationError", + "GravitySensor", + "Gyroscope", + "HashChangeEvent", + "Headers", + "HID", + "HIDConnectionEvent", + "HIDDevice", + "HIDInputReportEvent", + "Highlight", + "HighlightRegistry", + "History", + "HkdfParams", + "HmacImportParams", + "HmacKeyGenParams", + "HMDVRDevice", + "HTMLAnchorElement", + "HTMLAreaElement", + "HTMLAudioElement", + "HTMLBaseElement", + "HTMLBodyElement", + "HTMLBRElement", + "HTMLButtonElement", + "HTMLCanvasElement", + "HTMLCollection", + "HTMLDataElement", + "HTMLDataListElement", + "HTMLDetailsElement", + "HTMLDialogElement", + "HTMLDivElement", + "HTMLDListElement", + "HTMLDocument", + "HTMLElement", + "HTMLEmbedElement", + "HTMLFieldSetElement", + "HTMLFontElement", + "HTMLFormControlsCollection", + "HTMLFormElement", + "HTMLFrameSetElement", + "HTMLHeadElement", + "HTMLHeadingElement", + "HTMLHRElement", + "HTMLHtmlElement", + "HTMLIFrameElement", + "HTMLImageElement", + "HTMLInputElement", + "HTMLLabelElement", + "HTMLLegendElement", + "HTMLLIElement", + "HTMLLinkElement", + "HTMLMapElement", + "HTMLMarqueeElement", + "HTMLMediaElement", + "HTMLMenuElement", + "HTMLMenuItemElement", + "HTMLMetaElement", + "HTMLMeterElement", + "HTMLModElement", + "HTMLObjectElement", + "HTMLOListElement", + "HTMLOptGroupElement", + "HTMLOptionElement", + "HTMLOptionsCollection", + "HTMLOutputElement", + "HTMLParagraphElement", + "HTMLParamElement", + "HTMLPictureElement", + "HTMLPreElement", + "HTMLProgressElement", + "HTMLQuoteElement", + "HTMLScriptElement", + "HTMLSelectElement", + "HTMLSlotElement", + "HTMLSourceElement", + "HTMLSpanElement", + "HTMLStyleElement", + "HTMLTableCaptionElement", + "HTMLTableCellElement", + "HTMLTableColElement", + "HTMLTableElement", + "HTMLTableRowElement", + "HTMLTableSectionElement", + "HTMLTemplateElement", + "HTMLTextAreaElement", + "HTMLTimeElement", + "HTMLTitleElement", + "HTMLTrackElement", + "HTMLUListElement", + "HTMLUnknownElement", + "HTMLVideoElement", + "IDBCursor", + "IDBCursorWithValue", + "IDBDatabase", + "IDBFactory", + "IDBIndex", + "IDBKeyRange", + "IDBLocaleAwareKeyRange", + "IDBObjectStore", + "IDBOpenDBRequest", + "IDBRequest", + "IDBTransaction", + "IDBVersionChangeEvent", + "IdentityCredential", + "IdleDeadline", + "IdleDetector", + "IIRFilterNode", + "ImageBitmap", + "ImageBitmapRenderingContext", + "ImageCapture", + "ImageData", + "ImageDecoder", + "ImageTrack", + "ImageTrackList", + "Ink", + "InkPresenter", + "InputDeviceCapabilities", + "InputDeviceInfo", + "InputEvent", + "InstallEvent", + "IntersectionObserver", + "IntersectionObserverEntry", + "InterventionReportBody", + "Keyboard", + "KeyboardEvent", + "KeyboardLayoutMap", + "KeyframeEffect", + "LargestContentfulPaint", + "LaunchParams", + "LaunchQueue", + "LayoutShift", + "LayoutShiftAttribution", + "LinearAccelerationSensor", + "Location", + "Lock", + "LockManager", + "Magnetometer", + "MathMLElement", + "MediaCapabilities", + "MediaDeviceInfo", + "MediaDevices", + "MediaElementAudioSourceNode", + "MediaEncryptedEvent", + "MediaError", + "MediaImage", + "MediaKeyMessageEvent", + "MediaKeys", + "MediaKeySession", + "MediaKeyStatusMap", + "MediaKeySystemAccess", + "MediaList", + "MediaMetadata", + "MediaQueryList", + "MediaQueryListEvent", + "MediaRecorder", + "MediaRecorderErrorEvent", + "MediaSession", + "MediaSource", + "MediaSourceHandle", + "MediaStream", + "MediaStreamAudioDestinationNode", + "MediaStreamAudioSourceNode", + "MediaStreamEvent", + "MediaStreamTrack", + "MediaStreamTrackAudioSourceNode", + "MediaStreamTrackEvent", + "MediaStreamTrackGenerator", + "MediaStreamTrackProcessor", + "MediaTrackConstraints", + "MediaTrackSettings", + "MediaTrackSupportedConstraints", + "MerchantValidationEvent", + "MessageChannel", + "MessageEvent", + "MessagePort", + "Metadata", + "MIDIAccess", + "MIDIConnectionEvent", + "MIDIInput", + "MIDIInputMap", + "MIDIMessageEvent", + "MIDIOutput", + "MIDIOutputMap", + "MIDIPort", + "MimeType", + "MimeTypeArray", + "MouseEvent", + "MouseScrollEvent", + "MutationEvent", + "MutationObserver", + "MutationRecord", + "NamedNodeMap", + "NavigateEvent", + "Navigation", + "NavigationCurrentEntryChangeEvent", + "NavigationDestination", + "NavigationHistoryEntry", + "NavigationPreloadManager", + "NavigationTransition", + "Navigator", + "NavigatorUAData", + "NDEFMessage", + "NDEFReader", + "NDEFReadingEvent", + "NDEFRecord", + "NetworkInformation", + "Node", + "NodeIterator", + "NodeList", + "Notification", + "NotificationEvent", + "NotifyAudioAvailableEvent", + "OES_draw_buffers_indexed", + "OfflineAudioCompletionEvent", + "OfflineAudioContext", + "OffscreenCanvas", + "OffscreenCanvasRenderingContext2D", + "OrientationSensor", + "OscillatorNode", + "OTPCredential", + "OverconstrainedError", + "PageTransitionEvent", + "PaintWorkletGlobalScope", + "PannerNode", + "PasswordCredential", + "Path2D", + "PaymentAddress", + "PaymentManager", + "PaymentMethodChangeEvent", + "PaymentRequest", + "PaymentRequestEvent", + "PaymentRequestUpdateEvent", + "PaymentResponse", + "Pbkdf2Params", + "Performance", + "PerformanceElementTiming", + "PerformanceEntry", + "PerformanceEventTiming", + "PerformanceLongTaskTiming", + "PerformanceMark", + "PerformanceMeasure", + "PerformanceNavigation", + "PerformanceNavigationTiming", + "PerformanceObserver", + "PerformanceObserverEntryList", + "PerformancePaintTiming", + "PerformanceResourceTiming", + "PerformanceServerTiming", + "PerformanceTiming", + "PeriodicSyncEvent", + "PeriodicSyncManager", + "PeriodicWave", + "Permissions", + "PermissionStatus", + "PictureInPictureEvent", + "PictureInPictureWindow", + "Plugin", + "PluginArray", + "Point", + "PointerEvent", + "PopStateEvent", + "PositionSensorVRDevice", + "Presentation", + "PresentationAvailability", + "PresentationConnection", + "PresentationConnectionAvailableEvent", + "PresentationConnectionCloseEvent", + "PresentationConnectionList", + "PresentationReceiver", + "PresentationRequest", + "ProcessingInstruction", + "ProgressEvent", + "PromiseRejectionEvent", + "PublicKeyCredential", + "PushEvent", + "PushManager", + "PushMessageData", + "PushSubscription", + "PushSubscriptionOptions", + "RadioNodeList", + "Range", + "ReadableByteStreamController", + "ReadableStream", + "ReadableStreamBYOBReader", + "ReadableStreamBYOBRequest", + "ReadableStreamDefaultController", + "ReadableStreamDefaultReader", + "RelativeOrientationSensor", + "RemotePlayback", + "Report", + "ReportBody", + "ReportingObserver", + "Request", + "ResizeObserver", + "ResizeObserverEntry", + "ResizeObserverSize", + "Response", + "RsaHashedImportParams", + "RsaHashedKeyGenParams", + "RsaOaepParams", + "RsaPssParams", + "RTCAudioSourceStats", + "RTCCertificate", + "RTCDataChannel", + "RTCDataChannelEvent", + "RTCDtlsTransport", + "RTCDTMFSender", + "RTCDTMFToneChangeEvent", + "RTCError", + "RTCErrorEvent", + "RTCIceCandidate", + "RTCIceCandidatePair", + "RTCIceCandidatePairStats", + "RTCIceCandidateStats", + "RTCIceParameters", + "RTCIceServer", + "RTCIceTransport", + "RTCIdentityAssertion", + "RTCInboundRtpStreamStats", + "RTCOutboundRtpStreamStats", + "RTCPeerConnection", + "RTCPeerConnectionIceErrorEvent", + "RTCPeerConnectionIceEvent", + "RTCPeerConnectionStats", + "RTCRemoteOutboundRtpStreamStats", + "RTCRtpCodecParameters", + "RTCRtpContributingSource", + "RTCRtpEncodingParameters", + "RTCRtpReceiver", + "RTCRtpSender", + "RTCRtpStreamStats", + "RTCRtpTransceiver", + "RTCSctpTransport", + "RTCSessionDescription", + "RTCStatsReport", + "RTCTrackEvent", + "Sanitizer", + "Scheduler", + "Screen", + "ScreenOrientation", + "ScriptProcessorNode", + "SecurityPolicyViolationEvent", + "Selection", + "Sensor", + "SensorErrorEvent", + "Serial", + "SerialPort", + "ServiceWorker", + "ServiceWorkerContainer", + "ServiceWorkerGlobalScope", + "ServiceWorkerRegistration", + "ShadowRoot", + "SharedWorker", + "SharedWorkerGlobalScope", + "SourceBuffer", + "SourceBufferList", + "SpeechGrammar", + "SpeechGrammarList", + "SpeechRecognition", + "SpeechRecognitionAlternative", + "SpeechRecognitionErrorEvent", + "SpeechRecognitionEvent", + "SpeechRecognitionResult", + "SpeechRecognitionResultList", + "SpeechSynthesis", + "SpeechSynthesisErrorEvent", + "SpeechSynthesisEvent", + "SpeechSynthesisUtterance", + "SpeechSynthesisVoice", + "StaticRange", + "StereoPannerNode", + "Storage", + "StorageEvent", + "StorageManager", + "StylePropertyMap", + "StylePropertyMapReadOnly", + "StyleSheet", + "StyleSheetList", + "SubmitEvent", + "SubtleCrypto", + "SVGAElement", + "SVGAngle", + "SVGAnimateColorElement", + "SVGAnimatedAngle", + "SVGAnimatedBoolean", + "SVGAnimatedEnumeration", + "SVGAnimatedInteger", + "SVGAnimatedLength", + "SVGAnimatedLengthList", + "SVGAnimatedNumber", + "SVGAnimatedNumberList", + "SVGAnimatedPreserveAspectRatio", + "SVGAnimatedRect", + "SVGAnimatedString", + "SVGAnimatedTransformList", + "SVGAnimateElement", + "SVGAnimateMotionElement", + "SVGAnimateTransformElement", + "SVGAnimationElement", + "SVGCircleElement", + "SVGClipPathElement", + "SVGComponentTransferFunctionElement", + "SVGCursorElement", + "SVGDefsElement", + "SVGDescElement", + "SVGElement", + "SVGEllipseElement", + "SVGEvent", + "SVGFEBlendElement", + "SVGFEColorMatrixElement", + "SVGFEComponentTransferElement", + "SVGFECompositeElement", + "SVGFEConvolveMatrixElement", + "SVGFEDiffuseLightingElement", + "SVGFEDisplacementMapElement", + "SVGFEDistantLightElement", + "SVGFEDropShadowElement", + "SVGFEFloodElement", + "SVGFEFuncAElement", + "SVGFEFuncBElement", + "SVGFEFuncGElement", + "SVGFEFuncRElement", + "SVGFEGaussianBlurElement", + "SVGFEImageElement", + "SVGFEMergeElement", + "SVGFEMergeNodeElement", + "SVGFEMorphologyElement", + "SVGFEOffsetElement", + "SVGFEPointLightElement", + "SVGFESpecularLightingElement", + "SVGFESpotLightElement", + "SVGFETileElement", + "SVGFETurbulenceElement", + "SVGFilterElement", + "SVGFontElement", + "SVGFontFaceElement", + "SVGFontFaceFormatElement", + "SVGFontFaceNameElement", + "SVGFontFaceSrcElement", + "SVGFontFaceUriElement", + "SVGForeignObjectElement", + "SVGGElement", + "SVGGeometryElement", + "SVGGlyphElement", + "SVGGlyphRefElement", + "SVGGradientElement", + "SVGGraphicsElement", + "SVGHKernElement", + "SVGImageElement", + "SVGLength", + "SVGLengthList", + "SVGLinearGradientElement", + "SVGLineElement", + "SVGMarkerElement", + "SVGMaskElement", + "SVGMetadataElement", + "SVGMissingGlyphElement", + "SVGMPathElement", + "SVGNumber", + "SVGNumberList", + "SVGPathElement", + "SVGPatternElement", + "SVGPoint", + "SVGPointList", + "SVGPolygonElement", + "SVGPolylineElement", + "SVGPreserveAspectRatio", + "SVGRadialGradientElement", + "SVGRect", + "SVGRectElement", + "SVGRenderingIntent", + "SVGScriptElement", + "SVGSetElement", + "SVGStopElement", + "SVGStringList", + "SVGStyleElement", + "SVGSVGElement", + "SVGSwitchElement", + "SVGSymbolElement", + "SVGTextContentElement", + "SVGTextElement", + "SVGTextPathElement", + "SVGTextPositioningElement", + "SVGTitleElement", + "SVGTransform", + "SVGTransformList", + "SVGTRefElement", + "SVGTSpanElement", + "SVGUnitTypes", + "SVGUseElement", + "SVGViewElement", + "SVGVKernElement", + "SyncEvent", + "SyncManager", + "TaskAttributionTiming", + "TaskController", + "TaskPriorityChangeEvent", + "TaskSignal", + "Text", + "TextDecoder", + "TextDecoderStream", + "TextEncoder", + "TextEncoderStream", + "TextMetrics", + "TextTrack", + "TextTrackCue", + "TextTrackCueList", + "TextTrackList", + "TimeEvent", + "TimeRanges", + "ToggleEvent", + "Touch", + "TouchEvent", + "TouchList", + "TrackEvent", + "TransformStream", + "TransformStreamDefaultController", + "TransitionEvent", + "TreeWalker", + "TrustedHTML", + "TrustedScript", + "TrustedScriptURL", + "TrustedTypePolicy", + "TrustedTypePolicyFactory", + "UIEvent", + "URL", + "URLPattern", + "URLSearchParams", + "USB", + "USBAlternateInterface", + "USBConfiguration", + "USBConnectionEvent", + "USBDevice", + "USBEndpoint", + "USBInterface", + "USBInTransferResult", + "USBIsochronousInTransferPacket", + "USBIsochronousInTransferResult", + "USBIsochronousOutTransferPacket", + "USBIsochronousOutTransferResult", + "USBOutTransferResult", + "UserActivation", + "ValidityState", + "VideoColorSpace", + "VideoDecoder", + "VideoEncoder", + "VideoFrame", + "VideoPlaybackQuality", + "VideoTrack", + "VideoTrackList", + "ViewTransition", + "VirtualKeyboard", + "VisualViewport", + "VRDisplay", + "VRDisplayCapabilities", + "VRDisplayEvent", + "VREyeParameters", + "VRFieldOfView", + "VRFrameData", + "VRLayerInit", + "VRPose", + "VRStageParameters", + "VTTCue", + "VTTRegion", + "WakeLock", + "WakeLockSentinel", + "WaveShaperNode", + "WebGL2RenderingContext", + "WebGLActiveInfo", + "WebGLBuffer", + "WebGLContextEvent", + "WebGLFramebuffer", + "WebGLObject", + "WebGLProgram", + "WebGLQuery", + "WebGLRenderbuffer", + "WebGLRenderingContext", + "WebGLSampler", + "WebGLShader", + "WebGLShaderPrecisionFormat", + "WebGLSync", + "WebGLTexture", + "WebGLTransformFeedback", + "WebGLUniformLocation", + "WebGLVertexArrayObject", + "WebSocket", + "WebTransport", + "WebTransportBidirectionalStream", + "WebTransportDatagramDuplexStream", + "WebTransportError", + "WebTransportReceiveStream", + "WheelEvent", + "Window", + "WindowClient", + "WindowControlsOverlay", + "WindowControlsOverlayGeometryChangeEvent", + "Worker", + "WorkerGlobalScope", + "WorkerLocation", + "WorkerNavigator", + "Worklet", + "WorkletGlobalScope", + "WritableStream", + "WritableStreamDefaultController", + "WritableStreamDefaultWriter", + "XMLDocument", + "XMLHttpRequest", + "XMLHttpRequestEventTarget", + "XMLHttpRequestUpload", + "XMLSerializer", + "XPathEvaluator", + "XPathException", + "XPathExpression", + "XPathNSResolver", + "XPathResult", + "XRAnchor", + "XRAnchorSet", + "XRBoundedReferenceSpace", + "XRCompositionLayer", + "XRCPUDepthInformation", + "XRCubeLayer", + "XRCylinderLayer", + "XRDepthInformation", + "XREquirectLayer", + "XRFrame", + "XRHand", + "XRHitTestResult", + "XRHitTestSource", + "XRInputSource", + "XRInputSourceArray", + "XRInputSourceEvent", + "XRInputSourcesChangeEvent", + "XRJointPose", + "XRJointSpace", + "XRLayer", + "XRLayerEvent", + "XRLightEstimate", + "XRLightProbe", + "XRMediaBinding", + "XRPose", + "XRProjectionLayer", + "XRQuadLayer", + "XRRay", + "XRReferenceSpace", + "XRReferenceSpaceEvent", + "XRRenderState", + "XRRigidTransform", + "XRSession", + "XRSessionEvent", + "XRSpace", + "XRSubImage", + "XRSystem", + "XRTransientInputHitTestResult", + "XRTransientInputHitTestSource", + "XRView", + "XRViewerPose", + "XRViewport", + "XRWebGLBinding", + "XRWebGLDepthInformation", + "XRWebGLLayer", + "XRWebGLSubImage", + "XSLTProcessor" + ) end GlobalBuiltins diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/ImportResolverPass.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/ImportResolverPass.scala index e526688c..3b9d1cb8 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/ImportResolverPass.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/ImportResolverPass.scala @@ -13,127 +13,127 @@ import scala.util.{Failure, Success, Try} class ImportResolverPass(cpg: Cpg) extends XImportResolverPass(cpg): - private val pathPattern = Pattern.compile("[\"']([\\w/.]+)[\"']") + private val pathPattern = Pattern.compile("[\"']([\\w/.]+)[\"']") - override protected def optionalResolveImport( - fileName: String, - importCall: Call, - importedEntity: String, - importedAs: String, - diffGraph: DiffGraphBuilder - ): Unit = - val pathSep = ":" - val rawEntity = importedEntity.stripPrefix("./") - val alias = importedAs - val matcher = pathPattern.matcher(rawEntity) - val sep = Matcher.quoteReplacement(JFile.separator) - val root = s"$codeRoot${JFile.separator}" - val currentFile = s"$root$fileName" - // We want to know if the import is local since if an external name is used to match internal methods we may have - // false paths. - val isLocalImport = importedEntity.matches("^[.]+/?.*") - // TODO: At times there is an operation inside of a require, e.g. path.resolve(__dirname + "/../config/env/all.js") - // this tries to recover the string but does not perform string constant propagation - val entity = if matcher.find() then matcher.group(1) else rawEntity - val resolvedPath = better.files - .File(currentFile.stripSuffix(currentFile.split(sep).last), entity.split(pathSep).head) - .pathAsString - .stripPrefix(root) + override protected def optionalResolveImport( + fileName: String, + importCall: Call, + importedEntity: String, + importedAs: String, + diffGraph: DiffGraphBuilder + ): Unit = + val pathSep = ":" + val rawEntity = importedEntity.stripPrefix("./") + val alias = importedAs + val matcher = pathPattern.matcher(rawEntity) + val sep = Matcher.quoteReplacement(JFile.separator) + val root = s"$codeRoot${JFile.separator}" + val currentFile = s"$root$fileName" + // We want to know if the import is local since if an external name is used to match internal methods we may have + // false paths. + val isLocalImport = importedEntity.matches("^[.]+/?.*") + // TODO: At times there is an operation inside of a require, e.g. path.resolve(__dirname + "/../config/env/all.js") + // this tries to recover the string but does not perform string constant propagation + val entity = if matcher.find() then matcher.group(1) else rawEntity + val resolvedPath = better.files + .File(currentFile.stripSuffix(currentFile.split(sep).last), entity.split(pathSep).head) + .pathAsString + .stripPrefix(root) - val isImportingModule = !entity.contains(pathSep) + val isImportingModule = !entity.contains(pathSep) - def targetModule = Try( - if isLocalImport then - cpg - .file(s"${Pattern.quote(resolvedPath)}\\.?.*") - .method - .nameExact(":program") - else - Iterator.empty - ) match - case Failure(_) => - logger.warn(s"Unable to resolve import due to irregular regex at '$importedEntity'") - Iterator.empty - case Success(modules) => modules - - def targetAssignments = targetModule + def targetModule = Try( + if isLocalImport then + cpg + .file(s"${Pattern.quote(resolvedPath)}\\.?.*") + .method .nameExact(":program") - .flatMap(_._callViaContainsOut) - .assignment + else + Iterator.empty + ) match + case Failure(_) => + logger.warn(s"Unable to resolve import due to irregular regex at '$importedEntity'") + Iterator.empty + case Success(modules) => modules + + def targetAssignments = targetModule + .nameExact(":program") + .flatMap(_._callViaContainsOut) + .assignment - val matchingExports = if isImportingModule then - // If we are importing the whole module, we need to load all entities - targetAssignments - .code(s"\\_tmp\\_\\d+\\.\\w+ =.*", "(module\\.)?exports.*") - .dedup - .l - else - // If we are importing a specific entity, then we look for it here - targetAssignments - .code("^(module.)?exports.*") - .where(_.argument.codeExact(alias)) - .dedup - .l + val matchingExports = if isImportingModule then + // If we are importing the whole module, we need to load all entities + targetAssignments + .code(s"\\_tmp\\_\\d+\\.\\w+ =.*", "(module\\.)?exports.*") + .dedup + .l + else + // If we are importing a specific entity, then we look for it here + targetAssignments + .code("^(module.)?exports.*") + .where(_.argument.codeExact(alias)) + .dedup + .l - (if matchingExports.nonEmpty then - matchingExports.flatMap { exp => - exp.argument.l match - case ::(expCall: Call, ::(b: Identifier, _)) - if expCall.code.matches("^(module.)?exports[.]?.*") && b.name == alias => - val moduleMethods = targetModule.repeat(_.astChildren.isMethod)(_.emit).l - lazy val methodMatches = moduleMethods.name(b.name).l - lazy val constructorMatches = - moduleMethods.fullName( - s".*${b.name}$pathSep${XDefines.ConstructorMethodName}$$" - ).l - lazy val moduleExportsThisVariable = moduleMethods.body.local - .where(_.nameExact(b.name)) - .nonEmpty - // Exported function with only the name of the function - val methodPaths = - if methodMatches.nonEmpty then methodMatches.fullName.toSet - else constructorMatches.fullName.toSet - if methodPaths.nonEmpty then - methodPaths.flatMap(x => - cpg.method.fullNameExact(x).newTagNode( - "exported" - ).store()(diffGraph) - Set(ResolvedMethod(x, alias, Option("this")), ResolvedTypeDecl(x)) - ) - else if moduleExportsThisVariable then - Set(ResolvedMember(targetModule.fullName.head, b.name)) - else - Set.empty - case ::(x: Call, ::(b: MethodRef, _)) => - // Exported function with a method ref of the function - val methodName = - x.argumentOption(2).map(_.code).getOrElse(b.referencedMethod.name) - val (callName, receiver) = - if methodName == "exports" then (alias, Option("this")) - else - cpg.method.fullNameExact(methodName).newTagNode( - "exported" - ).store()(diffGraph) - (methodName, Option(alias)) - b.referencedMethod.astParent.iterator - .collectAll[Method] - .fullName - .map(x => ResolvedTypeDecl(x)) - .toSet ++ Set(ResolvedMethod(b.methodFullName, callName, receiver)) - case ::(_, ::(y: Call, _)) => - // Exported closure with a method ref within the AST of the RHS - y.ast.isMethodRef.map(mRef => - cpg.method.fullNameExact(mRef.methodFullName).newTagNode( - "exported" - ).store()(diffGraph) - ResolvedMethod(mRef.methodFullName, alias, Option("this")) - ).toSet - case _ => - Set.empty[ResolvedImport] - end match - }.toSet - else - Set(UnknownMethod(entity, alias, Option("this")), UnknownTypeDecl(entity)) - ).foreach(x => resolvedImportToTag(x, importCall, diffGraph)) - end optionalResolveImport + (if matchingExports.nonEmpty then + matchingExports.flatMap { exp => + exp.argument.l match + case ::(expCall: Call, ::(b: Identifier, _)) + if expCall.code.matches("^(module.)?exports[.]?.*") && b.name == alias => + val moduleMethods = targetModule.repeat(_.astChildren.isMethod)(_.emit).l + lazy val methodMatches = moduleMethods.name(b.name).l + lazy val constructorMatches = + moduleMethods.fullName( + s".*${b.name}$pathSep${XDefines.ConstructorMethodName}$$" + ).l + lazy val moduleExportsThisVariable = moduleMethods.body.local + .where(_.nameExact(b.name)) + .nonEmpty + // Exported function with only the name of the function + val methodPaths = + if methodMatches.nonEmpty then methodMatches.fullName.toSet + else constructorMatches.fullName.toSet + if methodPaths.nonEmpty then + methodPaths.flatMap(x => + cpg.method.fullNameExact(x).newTagNode( + "exported" + ).store()(diffGraph) + Set(ResolvedMethod(x, alias, Option("this")), ResolvedTypeDecl(x)) + ) + else if moduleExportsThisVariable then + Set(ResolvedMember(targetModule.fullName.head, b.name)) + else + Set.empty + case ::(x: Call, ::(b: MethodRef, _)) => + // Exported function with a method ref of the function + val methodName = + x.argumentOption(2).map(_.code).getOrElse(b.referencedMethod.name) + val (callName, receiver) = + if methodName == "exports" then (alias, Option("this")) + else + cpg.method.fullNameExact(methodName).newTagNode( + "exported" + ).store()(diffGraph) + (methodName, Option(alias)) + b.referencedMethod.astParent.iterator + .collectAll[Method] + .fullName + .map(x => ResolvedTypeDecl(x)) + .toSet ++ Set(ResolvedMethod(b.methodFullName, callName, receiver)) + case ::(_, ::(y: Call, _)) => + // Exported closure with a method ref within the AST of the RHS + y.ast.isMethodRef.map(mRef => + cpg.method.fullNameExact(mRef.methodFullName).newTagNode( + "exported" + ).store()(diffGraph) + ResolvedMethod(mRef.methodFullName, alias, Option("this")) + ).toSet + case _ => + Set.empty[ResolvedImport] + end match + }.toSet + else + Set(UnknownMethod(entity, alias, Option("this")), UnknownTypeDecl(entity)) + ).foreach(x => resolvedImportToTag(x, importCall, diffGraph)) + end optionalResolveImport end ImportResolverPass diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/ImportsPass.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/ImportsPass.scala index ed038eda..308e568e 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/ImportsPass.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/ImportsPass.scala @@ -18,10 +18,10 @@ import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.Assignment */ class ImportsPass(cpg: Cpg) extends XImportsPass(cpg): - override protected val importCallName: String = "require" + override protected val importCallName: String = "require" - override protected def importCallToPart(x: Call): Iterator[(Call, Assignment)] = - x.inAssignment.codeNot("var .*").map(y => (x, y)) + override protected def importCallToPart(x: Call): Iterator[(Call, Assignment)] = + x.inAssignment.codeNot("var .*").map(y => (x, y)) - override protected def importedEntityFromCall(call: Call): String = - X2Cpg.stripQuotes(call.argument(1).code) + override protected def importedEntityFromCall(call: Call): String = + X2Cpg.stripQuotes(call.argument(1).code) diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/JavaScriptInheritanceNamePass.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/JavaScriptInheritanceNamePass.scala index 64f0ff1f..36f46171 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/JavaScriptInheritanceNamePass.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/JavaScriptInheritanceNamePass.scala @@ -8,6 +8,6 @@ import io.shiftleft.codepropertygraph.Cpg */ class JavaScriptInheritanceNamePass(cpg: Cpg) extends XInheritanceFullNamePass(cpg): - override val pathSep: Char = ':' - override val moduleName: String = ":program" - override val fileExt: String = ".js" + override val pathSep: Char = ':' + override val moduleName: String = ":program" + override val fileExt: String = ".js" diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/JavaScriptTypeHintCallLinker.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/JavaScriptTypeHintCallLinker.scala index afe37ac9..546def13 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/JavaScriptTypeHintCallLinker.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/JavaScriptTypeHintCallLinker.scala @@ -7,8 +7,8 @@ import io.shiftleft.semanticcpg.language.* class JavaScriptTypeHintCallLinker(cpg: Cpg) extends XTypeHintCallLinker(cpg): - override protected val pathSep = ':' + override protected val pathSep = ':' - override protected def calls: Iterator[Call] = cpg.call - .or(_.nameNot(".*", ".*"), _.name(".new")) - .filter(c => calleeNames(c).nonEmpty && c.callee.isEmpty) + override protected def calls: Iterator[Call] = cpg.call + .or(_.nameNot(".*", ".*"), _.name(".new")) + .filter(c => calleeNames(c).nonEmpty && c.callee.isEmpty) diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/JavaScriptTypeRecovery.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/JavaScriptTypeRecovery.scala index a1d84109..4bb1c4c6 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/JavaScriptTypeRecovery.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/JavaScriptTypeRecovery.scala @@ -12,22 +12,22 @@ import overflowdb.BatchedUpdate.DiffGraphBuilder class JavaScriptTypeRecoveryPass(cpg: Cpg, config: XTypeRecoveryConfig = XTypeRecoveryConfig()) extends XTypeRecoveryPass[File](cpg, config): - override protected def generateRecoveryPass(state: XTypeRecoveryState): XTypeRecovery[File] = - new JavaScriptTypeRecovery(cpg, state) + override protected def generateRecoveryPass(state: XTypeRecoveryState): XTypeRecovery[File] = + new JavaScriptTypeRecovery(cpg, state) private class JavaScriptTypeRecovery(cpg: Cpg, state: XTypeRecoveryState) extends XTypeRecovery[File](cpg, state): - override def compilationUnit: Iterator[File] = cpg.file.iterator + override def compilationUnit: Iterator[File] = cpg.file.iterator - override def generateRecoveryForCompilationUnitTask( - unit: File, - builder: DiffGraphBuilder - ): RecoverForXCompilationUnit[File] = - val newConfig = state.config.copy(enabledDummyTypes = - state.isFinalIteration && state.config.enabledDummyTypes - ) - new RecoverForJavaScriptFile(cpg, unit, builder, state.copy(config = newConfig)) + override def generateRecoveryForCompilationUnitTask( + unit: File, + builder: DiffGraphBuilder + ): RecoverForXCompilationUnit[File] = + val newConfig = state.config.copy(enabledDummyTypes = + state.isFinalIteration && state.config.enabledDummyTypes + ) + new RecoverForJavaScriptFile(cpg, unit, builder, state.copy(config = newConfig)) private class RecoverForJavaScriptFile( cpg: Cpg, @@ -36,192 +36,192 @@ private class RecoverForJavaScriptFile( state: XTypeRecoveryState ) extends RecoverForXCompilationUnit[File](cpg, cu, builder, state): - private lazy val exportedIdentifiers = cu.method - .nameExact(":program") - .flatMap(_._callViaContainsOut) - .nameExact(Operators.assignment) - .filter(_.code.startsWith("exports.*")) - .argument - .isIdentifier - .name - .toSet - override protected val pathSep = ':' - - /** A heuristic method to determine if a call is a constructor or not. - */ - override protected def isConstructor(c: Call): Boolean = - c.name.endsWith("factory") && c.inCall.astParent.headOption.exists(_.isInstanceOf[Block]) - - override protected def isConstructor(name: String): Boolean = - !name.isBlank && (name.charAt(0).isUpper || name.endsWith("factory")) - - override protected def prepopulateSymbolTableEntry(x: AstNode): Unit = x match - case x @ (_: Identifier | _: Local | _: MethodParameterIn) - if x.property(PropertyNames.TYPE_FULL_NAME, Defines.Any) != Defines.Any => - val typeFullName = x.property(PropertyNames.TYPE_FULL_NAME, Defines.Any) - val typeHints = symbolTable.get(LocalVar(x.property( - PropertyNames.TYPE_FULL_NAME, - Defines.Any - ))) - typeFullName - lazy val cpgTypeFullName = cpg.typeDecl.nameExact(typeFullName).fullName.toSet - val resolvedTypeHints = - if typeHints.nonEmpty then symbolTable.put(x, typeHints) - else if cpgTypeFullName.nonEmpty then symbolTable.put(x, cpgTypeFullName) - else symbolTable.put(x, x.getKnownTypes) - if !resolvedTypeHints.contains(typeFullName) && resolvedTypeHints.sizeIs == 1 then - builder.setNodeProperty(x, PropertyNames.TYPE_FULL_NAME, resolvedTypeHints.head) - case x @ (_: Identifier | _: Local | _: MethodParameterIn) => - symbolTable.put(x, x.getKnownTypes) - case x: Call => symbolTable.put(x, (x.methodFullName +: x.dynamicTypeHintFullName).toSet) - case _ => - - override protected def prepopulateSymbolTable(): Unit = - super.prepopulateSymbolTable() - cu.ast.isMethod.foreach(f => - symbolTable.put(CallAlias(f.name, Option("this")), Set(f.fullName)) - ) - (cu.ast.isParameter.whereNot(_.nameExact("this")) ++ cu.ast.isMethod.methodReturn).filter( - hasTypes - ).foreach { p => - val resolvedHints = p.getKnownTypes - .map { t => - t.split("\\.").headOption match - case Some(base) if symbolTable.contains(LocalVar(base)) => - ( - t, - symbolTable.get(LocalVar(base)).map(x => s"$x${t.stripPrefix(base)}") - ) - case _ => (t, Set(t)) - } - .flatMap { - case (t, ts) if Set(t) == ts => Set(t) - case (_, ts) => ts.map(_.replaceAll("\\.(?!js::program)", pathSep.toString)) - } - p match - case _: MethodParameterIn => symbolTable.put(p, resolvedHints) - case _: MethodReturn if resolvedHints.sizeIs == 1 => - builder.setNodeProperty(p, PropertyNames.TYPE_FULL_NAME, resolvedHints.head) - case _: MethodReturn => - builder.setNodeProperty(p, PropertyNames.TYPE_FULL_NAME, Defines.Any) - builder.setNodeProperty( - p, - PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, - resolvedHints + private lazy val exportedIdentifiers = cu.method + .nameExact(":program") + .flatMap(_._callViaContainsOut) + .nameExact(Operators.assignment) + .filter(_.code.startsWith("exports.*")) + .argument + .isIdentifier + .name + .toSet + override protected val pathSep = ':' + + /** A heuristic method to determine if a call is a constructor or not. + */ + override protected def isConstructor(c: Call): Boolean = + c.name.endsWith("factory") && c.inCall.astParent.headOption.exists(_.isInstanceOf[Block]) + + override protected def isConstructor(name: String): Boolean = + !name.isBlank && (name.charAt(0).isUpper || name.endsWith("factory")) + + override protected def prepopulateSymbolTableEntry(x: AstNode): Unit = x match + case x @ (_: Identifier | _: Local | _: MethodParameterIn) + if x.property(PropertyNames.TYPE_FULL_NAME, Defines.Any) != Defines.Any => + val typeFullName = x.property(PropertyNames.TYPE_FULL_NAME, Defines.Any) + val typeHints = symbolTable.get(LocalVar(x.property( + PropertyNames.TYPE_FULL_NAME, + Defines.Any + ))) - typeFullName + lazy val cpgTypeFullName = cpg.typeDecl.nameExact(typeFullName).fullName.toSet + val resolvedTypeHints = + if typeHints.nonEmpty then symbolTable.put(x, typeHints) + else if cpgTypeFullName.nonEmpty then symbolTable.put(x, cpgTypeFullName) + else symbolTable.put(x, x.getKnownTypes) + if !resolvedTypeHints.contains(typeFullName) && resolvedTypeHints.sizeIs == 1 then + builder.setNodeProperty(x, PropertyNames.TYPE_FULL_NAME, resolvedTypeHints.head) + case x @ (_: Identifier | _: Local | _: MethodParameterIn) => + symbolTable.put(x, x.getKnownTypes) + case x: Call => symbolTable.put(x, (x.methodFullName +: x.dynamicTypeHintFullName).toSet) + case _ => + + override protected def prepopulateSymbolTable(): Unit = + super.prepopulateSymbolTable() + cu.ast.isMethod.foreach(f => + symbolTable.put(CallAlias(f.name, Option("this")), Set(f.fullName)) + ) + (cu.ast.isParameter.whereNot(_.nameExact("this")) ++ cu.ast.isMethod.methodReturn).filter( + hasTypes + ).foreach { p => + val resolvedHints = p.getKnownTypes + .map { t => + t.split("\\.").headOption match + case Some(base) if symbolTable.contains(LocalVar(base)) => + ( + t, + symbolTable.get(LocalVar(base)).map(x => s"$x${t.stripPrefix(base)}") ) - case _ => - } - end prepopulateSymbolTable - - override protected def isField(i: Identifier): Boolean = - state.isFieldCache.getOrElseUpdate( - i.id(), - exportedIdentifiers.contains(i.name) || super.isField(i) - ) - - override protected def visitIdentifierAssignedToConstructor( - i: Identifier, - c: Call - ): Set[String] = - val constructorPaths = if c.methodFullName.endsWith(".alloc") then - val newOp = c.inAssignment.astSiblings.isCall.nameExact(".new").headOption - val newChildren = newOp.astChildren.l - val possibleImportIdentifier = newChildren.isIdentifier.headOption match - case Some(i) if GlobalBuiltins.builtins.contains(i.name) => Set(s"__ecma.${i.name}") - case Some(i) => - val typs = symbolTable.get(CallAlias(i.name, Option("this"))) - if typs.nonEmpty then newOp.foreach(symbolTable.put(_, typs)) - symbolTable.get(i) - case None => Set.empty[String] - lazy val possibleConstructorPointer = - newChildren.astChildren.isFieldIdentifier.map(f => - CallAlias(f.canonicalName, Some("this")) - ).headOption match - case Some(fi) => symbolTable.get(fi) - case None => Set.empty[String] - - if possibleImportIdentifier.nonEmpty then possibleImportIdentifier - else if possibleConstructorPointer.nonEmpty then possibleConstructorPointer - else Set.empty[String] - else (symbolTable.get(c) + c.methodFullName).map(t => t.stripSuffix(".factory")) - associateTypes(i, constructorPaths) - end visitIdentifierAssignedToConstructor - - override protected def visitIdentifierAssignedToOperator( - i: Identifier, - c: Call, - operation: String - ): Set[String] = - operation match - case ".new" => - c.astChildren.l match - case ::(fa: Call, ::(i: Identifier, _)) if fa.name == Operators.fieldAccess => - symbolTable.append( - c, - visitIdentifierAssignedToFieldLoad(i, new FieldAccess(fa)).map(t => - s"$t$pathSep$ConstructorMethodName" - ) - ) - case _ => Set.empty - case _ => super.visitIdentifierAssignedToOperator(i, c, operation) - - override protected def associateInterproceduralTypes( - i: Identifier, - fieldFullName: String, - fieldName: String, - globalTypes: Set[String], - baseTypes: Set[String] - ): Set[String] = - if symbolTable.contains(LocalVar(fieldName)) then - val fieldTypes = symbolTable.get(LocalVar(fieldName)) - symbolTable.append(i, fieldTypes) - else if symbolTable.contains(CallAlias(fieldName, Option("this"))) then - symbolTable.get(CallAlias(fieldName, Option("this"))) - else - super.associateInterproceduralTypes( - i: Identifier, - fieldFullName: String, - fieldName: String, - globalTypes: Set[String], - baseTypes: Set[String] + case _ => (t, Set(t)) + } + .flatMap { + case (t, ts) if Set(t) == ts => Set(t) + case (_, ts) => ts.map(_.replaceAll("\\.(?!js::program)", pathSep.toString)) + } + p match + case _: MethodParameterIn => symbolTable.put(p, resolvedHints) + case _: MethodReturn if resolvedHints.sizeIs == 1 => + builder.setNodeProperty(p, PropertyNames.TYPE_FULL_NAME, resolvedHints.head) + case _: MethodReturn => + builder.setNodeProperty(p, PropertyNames.TYPE_FULL_NAME, Defines.Any) + builder.setNodeProperty( + p, + PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, + resolvedHints ) - - override protected def visitIdentifierAssignedToCall(i: Identifier, c: Call): Set[String] = - // Instead of returning empty, this must visit and identify the default export - if c.name == "require" then Set.empty - else super.visitIdentifierAssignedToCall(i, c) - - override protected def visitIdentifierAssignedToMethodRef( - i: Identifier, - m: MethodRef, - rec: Option[String] = None - ): Set[String] = - super.visitIdentifierAssignedToMethodRef(i, m, Option("this")) - - override protected def visitIdentifierAssignedToTypeRef( - i: Identifier, - t: TypeRef, - rec: Option[String] = None - ): Set[String] = - super.visitIdentifierAssignedToTypeRef(i, t, Option("this")) - - override protected def postSetTypeInformation(): Unit = - // often there are "this" identifiers with type hints but this can be set to a type hint if they meet the criteria - cu.method - .flatMap(_._identifierViaContainsOut) - .nameExact("this") - .where(_.typeFullNameExact(Defines.Any)) - .filterNot(_.dynamicTypeHintFullName.isEmpty) - .foreach(setTypeFromTypeHints) - - protected override def storeIdentifierTypeInfo(i: Identifier, types: Seq[String]): Unit = - super.storeIdentifierTypeInfo( - i, - types.map(_.stripSuffix(s"$pathSep${XDefines.ConstructorMethodName}")) + case _ => + } + end prepopulateSymbolTable + + override protected def isField(i: Identifier): Boolean = + state.isFieldCache.getOrElseUpdate( + i.id(), + exportedIdentifiers.contains(i.name) || super.isField(i) + ) + + override protected def visitIdentifierAssignedToConstructor( + i: Identifier, + c: Call + ): Set[String] = + val constructorPaths = if c.methodFullName.endsWith(".alloc") then + val newOp = c.inAssignment.astSiblings.isCall.nameExact(".new").headOption + val newChildren = newOp.astChildren.l + val possibleImportIdentifier = newChildren.isIdentifier.headOption match + case Some(i) if GlobalBuiltins.builtins.contains(i.name) => Set(s"__ecma.${i.name}") + case Some(i) => + val typs = symbolTable.get(CallAlias(i.name, Option("this"))) + if typs.nonEmpty then newOp.foreach(symbolTable.put(_, typs)) + symbolTable.get(i) + case None => Set.empty[String] + lazy val possibleConstructorPointer = + newChildren.astChildren.isFieldIdentifier.map(f => + CallAlias(f.canonicalName, Some("this")) + ).headOption match + case Some(fi) => symbolTable.get(fi) + case None => Set.empty[String] + + if possibleImportIdentifier.nonEmpty then possibleImportIdentifier + else if possibleConstructorPointer.nonEmpty then possibleConstructorPointer + else Set.empty[String] + else (symbolTable.get(c) + c.methodFullName).map(t => t.stripSuffix(".factory")) + associateTypes(i, constructorPaths) + end visitIdentifierAssignedToConstructor + + override protected def visitIdentifierAssignedToOperator( + i: Identifier, + c: Call, + operation: String + ): Set[String] = + operation match + case ".new" => + c.astChildren.l match + case ::(fa: Call, ::(i: Identifier, _)) if fa.name == Operators.fieldAccess => + symbolTable.append( + c, + visitIdentifierAssignedToFieldLoad(i, new FieldAccess(fa)).map(t => + s"$t$pathSep$ConstructorMethodName" + ) + ) + case _ => Set.empty + case _ => super.visitIdentifierAssignedToOperator(i, c, operation) + + override protected def associateInterproceduralTypes( + i: Identifier, + fieldFullName: String, + fieldName: String, + globalTypes: Set[String], + baseTypes: Set[String] + ): Set[String] = + if symbolTable.contains(LocalVar(fieldName)) then + val fieldTypes = symbolTable.get(LocalVar(fieldName)) + symbolTable.append(i, fieldTypes) + else if symbolTable.contains(CallAlias(fieldName, Option("this"))) then + symbolTable.get(CallAlias(fieldName, Option("this"))) + else + super.associateInterproceduralTypes( + i: Identifier, + fieldFullName: String, + fieldName: String, + globalTypes: Set[String], + baseTypes: Set[String] ) - protected override def storeLocalTypeInfo(i: Local, types: Seq[String]): Unit = - super.storeLocalTypeInfo( - i, - types.map(_.stripSuffix(s"$pathSep${XDefines.ConstructorMethodName}")) - ) + override protected def visitIdentifierAssignedToCall(i: Identifier, c: Call): Set[String] = + // Instead of returning empty, this must visit and identify the default export + if c.name == "require" then Set.empty + else super.visitIdentifierAssignedToCall(i, c) + + override protected def visitIdentifierAssignedToMethodRef( + i: Identifier, + m: MethodRef, + rec: Option[String] = None + ): Set[String] = + super.visitIdentifierAssignedToMethodRef(i, m, Option("this")) + + override protected def visitIdentifierAssignedToTypeRef( + i: Identifier, + t: TypeRef, + rec: Option[String] = None + ): Set[String] = + super.visitIdentifierAssignedToTypeRef(i, t, Option("this")) + + override protected def postSetTypeInformation(): Unit = + // often there are "this" identifiers with type hints but this can be set to a type hint if they meet the criteria + cu.method + .flatMap(_._identifierViaContainsOut) + .nameExact("this") + .where(_.typeFullNameExact(Defines.Any)) + .filterNot(_.dynamicTypeHintFullName.isEmpty) + .foreach(setTypeFromTypeHints) + + protected override def storeIdentifierTypeInfo(i: Identifier, types: Seq[String]): Unit = + super.storeIdentifierTypeInfo( + i, + types.map(_.stripSuffix(s"$pathSep${XDefines.ConstructorMethodName}")) + ) + + protected override def storeLocalTypeInfo(i: Local, types: Seq[String]): Unit = + super.storeLocalTypeInfo( + i, + types.map(_.stripSuffix(s"$pathSep${XDefines.ConstructorMethodName}")) + ) end RecoverForJavaScriptFile diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/JsMetaDataPass.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/JsMetaDataPass.scala index 8a850409..064af144 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/JsMetaDataPass.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/JsMetaDataPass.scala @@ -8,9 +8,9 @@ import io.shiftleft.passes.CpgPass class JsMetaDataPass(cpg: Cpg, hash: String, inputPath: String) extends CpgPass(cpg): - override def run(diffGraph: DiffGraphBuilder): Unit = - val absolutePathToRoot = File(inputPath).path.toAbsolutePath.toString - val metaNode = NewMetaData().language(Languages.JSSRC).root(absolutePathToRoot).hash( - hash - ).version("0.1") - diffGraph.addNode(metaNode) + override def run(diffGraph: DiffGraphBuilder): Unit = + val absolutePathToRoot = File(inputPath).path.toAbsolutePath.toString + val metaNode = NewMetaData().language(Languages.JSSRC).root(absolutePathToRoot).hash( + hash + ).version("0.1") + diffGraph.addNode(metaNode) diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/PrivateKeyFilePass.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/PrivateKeyFilePass.scala index 1e2a355b..532c3fdb 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/PrivateKeyFilePass.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/PrivateKeyFilePass.scala @@ -11,15 +11,15 @@ import scala.util.matching.Regex class PrivateKeyFilePass(cpg: Cpg, config: Config, report: Report = new Report()) extends ConfigPass(cpg, config, report): - private val PrivateKeyRegex: Regex = """.*RSA\sPRIVATE\sKEY.*""".r + private val PrivateKeyRegex: Regex = """.*RSA\sPRIVATE\sKEY.*""".r - override val allExtensions: Set[String] = Set(".key") - override val selectedExtensions: Set[String] = Set(".key") + override val allExtensions: Set[String] = Set(".key") + override val selectedExtensions: Set[String] = Set(".key") - override def fileContent(file: File): Seq[String] = - Seq("Content omitted for security reasons.") + override def fileContent(file: File): Seq[String] = + Seq("Content omitted for security reasons.") - override def generateParts(): Array[File] = - configFiles(config, selectedExtensions).toArray.filter(p => - IOUtils.readLinesInFile(p.path).exists(PrivateKeyRegex.matches) - ) + override def generateParts(): Array[File] = + configFiles(config, selectedExtensions).toArray.filter(p => + IOUtils.readLinesInFile(p.path).exists(PrivateKeyRegex.matches) + ) diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/TypeNodePass.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/TypeNodePass.scala index d761278c..dd761034 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/TypeNodePass.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/passes/TypeNodePass.scala @@ -5,12 +5,12 @@ import io.shiftleft.codepropertygraph.generated.nodes.NewType import io.shiftleft.passes.CpgPass class TypeNodePass(usedTypes: List[(String, String)], cpg: Cpg) extends CpgPass(cpg, "types"): - override def run(diffGraph: DiffGraphBuilder): Unit = - val filteredTypes = usedTypes.filterNot { case (name, _) => - name == Defines.Any || Defines.JsTypes.contains(name) - } + override def run(diffGraph: DiffGraphBuilder): Unit = + val filteredTypes = usedTypes.filterNot { case (name, _) => + name == Defines.Any || Defines.JsTypes.contains(name) + } - filteredTypes.sortBy(_._2).foreach { case (name, fullName) => - val typeNode = NewType().name(name).fullName(fullName).typeDeclFullName(fullName) - diffGraph.addNode(typeNode) - } + filteredTypes.sortBy(_._2).foreach { case (name, fullName) => + val typeNode = NewType().name(name).fullName(fullName).typeDeclFullName(fullName) + diffGraph.addNode(typeNode) + } diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/preprocessing/EjsPreprocessor.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/preprocessing/EjsPreprocessor.scala index a84bd7c7..08c0a87a 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/preprocessing/EjsPreprocessor.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/preprocessing/EjsPreprocessor.scala @@ -4,81 +4,82 @@ import scala.collection.mutable class EjsPreprocessor: - private val CommentTag = "<%#" - private val TagGroupsRegex = """(<%[=\-_#]?)([\s\S]*?)([-_#]?%>)""".r - private val ScriptGroupsRegex = """()""".r - private val Tags = List("<%#", "<%=", "<%-", "<%_", "-%>", "_%>", "#%>", "%>") + private val CommentTag = "<%#" + private val TagGroupsRegex = """(<%[=\-_#]?)([\s\S]*?)([-_#]?%>)""".r + private val ScriptGroupsRegex = """()""".r + private val Tags = List("<%#", "<%=", "<%-", "<%_", "-%>", "_%>", "#%>", "%>") - private def stripScriptTag(code: String): String = - var x = code.replaceAll("", "%> ") - ScriptGroupsRegex.findAllIn(code).matchData.foreach { ma => - var scriptBlock = ma.group(2) - val matches = TagGroupsRegex.findAllIn(scriptBlock).matchData.toList - matches.foreach { - case mat if mat.group(1) == "<%" && mat.group(3) == "-%>" => - scriptBlock = scriptBlock.replace( - mat.toString(), - " " * mat.toString().replaceAll("[^\\s]", " ").length - ) - case _ => - } - Tags.foreach { tag => - scriptBlock = scriptBlock.replaceAll(tag, " " * tag.length) - } - x = x.replace(ma.group(2), scriptBlock) - } - x - end stripScriptTag + private def stripScriptTag(code: String): String = + var x = code.replaceAll("", "%> ") + ScriptGroupsRegex.findAllIn(code).matchData.foreach { ma => + var scriptBlock = ma.group(2) + val matches = TagGroupsRegex.findAllIn(scriptBlock).matchData.toList + matches.foreach { + case mat if mat.group(1) == "<%" && mat.group(3) == "-%>" => + scriptBlock = scriptBlock.replace( + mat.toString(), + " " * mat.toString().replaceAll("[^\\s]", " ").length + ) + case _ => + } + Tags.foreach { tag => + scriptBlock = scriptBlock.replaceAll(tag, " " * tag.length) + } + x = x.replace(ma.group(2), scriptBlock) + } + x + end stripScriptTag - private def needsSemicolon(code: String): Boolean = - !code.trim.endsWith("{") && !code.trim.endsWith("}") && !code.trim.endsWith(";") + private def needsSemicolon(code: String): Boolean = + !code.trim.endsWith("{") && !code.trim.endsWith("}") && !code.trim.endsWith(";") - def preprocess(code: String): String = - val codeWithoutScriptTag = stripScriptTag(code) - val codeAsCharArray = codeWithoutScriptTag.toCharArray - val preprocessedCode = new mutable.StringBuilder(codeAsCharArray.length) - val matches = TagGroupsRegex.findAllIn(codeWithoutScriptTag).matchData.toList + def preprocess(code: String): String = + val codeWithoutScriptTag = stripScriptTag(code) + val codeAsCharArray = codeWithoutScriptTag.toCharArray + val preprocessedCode = new mutable.StringBuilder(codeAsCharArray.length) + val matches = TagGroupsRegex.findAllIn(codeWithoutScriptTag).matchData.toList - val positions = matches.flatMap { - case ma if ma.group(1) == CommentTag => None // ignore comments - case ma if ma.group(2).trim.startsWith("include ") => - None // ignore including other ejs templates - case ma => - val start = ma.start + ma.group(1).length - val end = ma.end - ma.group(3).length - Option((start, end)) - } + val positions = matches.flatMap { + case ma if ma.group(1) == CommentTag => None // ignore comments + case ma if ma.group(2).trim.startsWith("include ") => + None // ignore including other ejs templates + case ma => + val start = ma.start + ma.group(1).length + val end = ma.end - ma.group(3).length + Option((start, end)) + } - codeAsCharArray.zipWithIndex.foreach { - case (currChar, _) if currChar == '\n' || currChar == '\r' => - preprocessedCode.append(currChar) - case (currChar, index) if positions.exists { case (start, end) => - index >= start && index < end - } => - preprocessedCode.append(currChar) - case _ => - preprocessedCode.append(" ") - } + codeAsCharArray.zipWithIndex.foreach { + case (currChar, _) if currChar == '\n' || currChar == '\r' => + preprocessedCode.append(currChar) + case (currChar, index) if positions.exists { case (start, end) => + index >= start && index < end + } => + preprocessedCode.append(currChar) + case _ => + preprocessedCode.append(" ") + } - var codeWithoutSemicolon = preprocessedCode.toString() - val alreadyReplaced = mutable.ArrayBuffer.empty[(Int, Int)] - matches.foreach { - case ma if ma.group(1) == CommentTag => // ignore comments - case ma - if ma.group(2).trim.startsWith( - "include " - ) => // ignore including other ejs templates - case ma if needsSemicolon(ma.group(2)) => - val start = ma.start + ma.group(1).length - val end = ma.end - ma.group(3).length - if !alreadyReplaced.contains((start, end)) then - val replacementCode = s"${ma.group(2)};" - codeWithoutSemicolon = - s"${codeWithoutSemicolon.substring(0, start)}$replacementCode${codeWithoutSemicolon.substring(end + 1, codeWithoutSemicolon.length)}" - alreadyReplaced.append((start, end)) - case _ => // others are fine already - } + var codeWithoutSemicolon = preprocessedCode.toString() + val alreadyReplaced = mutable.ArrayBuffer.empty[(Int, Int)] + matches.foreach { + case ma if ma.group(1) == CommentTag => // ignore comments + case ma + if ma.group(2).trim.startsWith( + "include " + ) => // ignore including other ejs templates + case ma if needsSemicolon(ma.group(2)) => + val start = ma.start + ma.group(1).length + val end = ma.end - ma.group(3).length + if !alreadyReplaced.contains((start, end)) then + val replacementCode = s"${ma.group(2)};" + codeWithoutSemicolon = + s"${codeWithoutSemicolon.substring(0, start)}$replacementCode${codeWithoutSemicolon + .substring(end + 1, codeWithoutSemicolon.length)}" + alreadyReplaced.append((start, end)) + case _ => // others are fine already + } - codeWithoutSemicolon - end preprocess + codeWithoutSemicolon + end preprocess end EjsPreprocessor diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/utils/AstGenRunner.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/utils/AstGenRunner.scala index 41246597..c094b058 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/utils/AstGenRunner.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/utils/AstGenRunner.scala @@ -18,269 +18,269 @@ import scala.util.Try object AstGenRunner: - private val logger = LoggerFactory.getLogger(getClass) - - private val LineLengthThreshold: Int = 10000 - - private val TypeDefinitionFileExtensions = List(".t.ts", ".d.ts") - - private val MinifiedPathRegex: Regex = ".*([.-]min\\..*js|bundle\\.js)".r - - private val IgnoredTestsRegex: Seq[Regex] = - List( - ".*[.-]spec\\.js".r, - ".*[.-]mock\\.js".r, - ".*[.-]e2e\\.js".r, - ".*[.-]test\\.js".r, - ".*cypress\\.json".r, - ".*test.*\\.json".r - ) + private val logger = LoggerFactory.getLogger(getClass) + + private val LineLengthThreshold: Int = 10000 + + private val TypeDefinitionFileExtensions = List(".t.ts", ".d.ts") + + private val MinifiedPathRegex: Regex = ".*([.-]min\\..*js|bundle\\.js)".r + + private val IgnoredTestsRegex: Seq[Regex] = + List( + ".*[.-]spec\\.js".r, + ".*[.-]mock\\.js".r, + ".*[.-]e2e\\.js".r, + ".*[.-]test\\.js".r, + ".*cypress\\.json".r, + ".*test.*\\.json".r + ) + + private val IgnoredFilesRegex: Seq[Regex] = List( + ".*jest\\.config.*".r, + ".*webpack\\..*\\.js".r, + ".*vue\\.config\\.js".r, + ".*babel\\.config\\.js".r, + ".*chunk-vendors.*\\.js".r, // commonly found in webpack / vue.js projects + ".*app~.*\\.js".r, // commonly found in webpack / vue.js projects + ".*\\.chunk\\.js".r, + ".*\\.babelrc.*".r, + ".*\\.eslint.*".r, + ".*\\.tslint.*".r, + ".*\\.stylelintrc\\.js".r, + ".*rollup\\.config.*".r, + ".*\\.types\\.js".r, + ".*\\.cjs\\.js".r, + ".*eslint-local-rules\\.js".r, + ".*\\.devcontainer\\.json".r, + ".*Gruntfile\\.js".r, + ".*i18n.*\\.json".r + ) + + case class AstGenRunnerResult( + parsedFiles: List[(String, String)] = List.empty, + skippedFiles: List[(String, String)] = List.empty + ) + + lazy private val executableName = Environment.operatingSystem match + case Environment.OperatingSystemType.Windows => "astgen-win.exe" + case Environment.OperatingSystemType.Linux => "astgen-linux" + case Environment.OperatingSystemType.Mac => + Environment.architecture match + case Environment.ArchitectureType.X86 => "astgen-macos" + case Environment.ArchitectureType.ARM => "astgen-macos-arm" + case Environment.OperatingSystemType.Unknown => + logger.warn("Could not detect OS version! Defaulting to 'Linux'.") + "astgen-linux" + + lazy private val executableDir: String = + val dir = getClass.getProtectionDomain.getCodeSource.getLocation.toString + val indexOfLib = dir.lastIndexOf("lib") + val fixedDir = if indexOfLib != -1 then + new java.io.File(dir.substring("file:".length, indexOfLib)).toString + else + val indexOfTarget = dir.lastIndexOf("target") + if indexOfTarget != -1 then + new java.io.File(dir.substring("file:".length, indexOfTarget)).toString + else + "." + Paths.get(fixedDir, "/bin/astgen").toAbsolutePath.toString + + private def hasCompatibleAstGenVersion(astGenVersion: String): Boolean = + ExternalCommand.run("astgen --version", ".").toOption.map(_.mkString.strip()) match + case Some(installedVersion) + if installedVersion != "unknown" && + Try(VersionHelper.compare(installedVersion, astGenVersion)).toOption.getOrElse( + -1 + ) >= 0 => + logger.debug(s"Using local astgen v$installedVersion from systems PATH") + true + case Some(installedVersion) => + logger.debug( + s"Found local astgen v$installedVersion in systems PATH but jssrc2cpg requires at least v$astGenVersion" + ) + false + case _ => false - private val IgnoredFilesRegex: Seq[Regex] = List( - ".*jest\\.config.*".r, - ".*webpack\\..*\\.js".r, - ".*vue\\.config\\.js".r, - ".*babel\\.config\\.js".r, - ".*chunk-vendors.*\\.js".r, // commonly found in webpack / vue.js projects - ".*app~.*\\.js".r, // commonly found in webpack / vue.js projects - ".*\\.chunk\\.js".r, - ".*\\.babelrc.*".r, - ".*\\.eslint.*".r, - ".*\\.tslint.*".r, - ".*\\.stylelintrc\\.js".r, - ".*rollup\\.config.*".r, - ".*\\.types\\.js".r, - ".*\\.cjs\\.js".r, - ".*eslint-local-rules\\.js".r, - ".*\\.devcontainer\\.json".r, - ".*Gruntfile\\.js".r, - ".*i18n.*\\.json".r - ) - - case class AstGenRunnerResult( - parsedFiles: List[(String, String)] = List.empty, - skippedFiles: List[(String, String)] = List.empty - ) - - lazy private val executableName = Environment.operatingSystem match - case Environment.OperatingSystemType.Windows => "astgen-win.exe" - case Environment.OperatingSystemType.Linux => "astgen-linux" - case Environment.OperatingSystemType.Mac => - Environment.architecture match - case Environment.ArchitectureType.X86 => "astgen-macos" - case Environment.ArchitectureType.ARM => "astgen-macos-arm" - case Environment.OperatingSystemType.Unknown => - logger.warn("Could not detect OS version! Defaulting to 'Linux'.") - "astgen-linux" - - lazy private val executableDir: String = - val dir = getClass.getProtectionDomain.getCodeSource.getLocation.toString - val indexOfLib = dir.lastIndexOf("lib") - val fixedDir = if indexOfLib != -1 then - new java.io.File(dir.substring("file:".length, indexOfLib)).toString - else - val indexOfTarget = dir.lastIndexOf("target") - if indexOfTarget != -1 then - new java.io.File(dir.substring("file:".length, indexOfTarget)).toString - else - "." - Paths.get(fixedDir, "/bin/astgen").toAbsolutePath.toString - - private def hasCompatibleAstGenVersion(astGenVersion: String): Boolean = - ExternalCommand.run("astgen --version", ".").toOption.map(_.mkString.strip()) match - case Some(installedVersion) - if installedVersion != "unknown" && - Try(VersionHelper.compare(installedVersion, astGenVersion)).toOption.getOrElse( - -1 - ) >= 0 => - logger.debug(s"Using local astgen v$installedVersion from systems PATH") - true - case Some(installedVersion) => - logger.debug( - s"Found local astgen v$installedVersion in systems PATH but jssrc2cpg requires at least v$astGenVersion" - ) - false - case _ => false - - private lazy val astGenCommand = - val conf = ConfigFactory.load - val astGenVersion = conf.getString("jssrc2cpg.astgen_version") - if hasCompatibleAstGenVersion(astGenVersion) then - "astgen" - else - s"$executableDir/$executableName" + private lazy val astGenCommand = + val conf = ConfigFactory.load + val astGenVersion = conf.getString("jssrc2cpg.astgen_version") + if hasCompatibleAstGenVersion(astGenVersion) then + "astgen" + else + s"$executableDir/$executableName" end AstGenRunner class AstGenRunner(config: Config): - import AstGenRunner.* - - private val executableArgs = if !config.tsTypes then " --no-tsTypes" else "" - - private def skippedFiles(in: File, astGenOut: List[String]): List[String] = - val skipped = astGenOut.collect { - case out if !out.startsWith("Converted") && !out.startsWith("Retrieving") => - val filename = out.substring(0, out.indexOf(" ")) - val reason = out.substring(out.indexOf(" ") + 1) - logger.warn(s"\t- failed to parse '${in / filename}': '$reason'") - Option(filename) - case out => - logger.debug(s"\t+ $out") - None - } - skipped.flatten - - private def isIgnoredByUserConfig(filePath: String): Boolean = - lazy val isInIgnoredFiles = config.ignoredFiles.exists { - case ignorePath if File(ignorePath).isDirectory => filePath.startsWith(ignorePath) - case ignorePath => filePath == ignorePath - } - lazy val isInIgnoredFileRegex = config.ignoredFilesRegex.matches(filePath) - if isInIgnoredFiles || isInIgnoredFileRegex then - logger.debug(s"'$filePath' ignored by user configuration") - true - else + import AstGenRunner.* + + private val executableArgs = if !config.tsTypes then " --no-tsTypes" else "" + + private def skippedFiles(in: File, astGenOut: List[String]): List[String] = + val skipped = astGenOut.collect { + case out if !out.startsWith("Converted") && !out.startsWith("Retrieving") => + val filename = out.substring(0, out.indexOf(" ")) + val reason = out.substring(out.indexOf(" ") + 1) + logger.warn(s"\t- failed to parse '${in / filename}': '$reason'") + Option(filename) + case out => + logger.debug(s"\t+ $out") + None + } + skipped.flatten + + private def isIgnoredByUserConfig(filePath: String): Boolean = + lazy val isInIgnoredFiles = config.ignoredFiles.exists { + case ignorePath if File(ignorePath).isDirectory => filePath.startsWith(ignorePath) + case ignorePath => filePath == ignorePath + } + lazy val isInIgnoredFileRegex = config.ignoredFilesRegex.matches(filePath) + if isInIgnoredFiles || isInIgnoredFileRegex then + logger.debug(s"'$filePath' ignored by user configuration") + true + else + false + + private def isMinifiedFile(filePath: String): Boolean = filePath match + case p if MinifiedPathRegex.matches(p) => true + case p if File(p).exists && p.endsWith(".js") => + val lines = IOUtils.readLinesInFile(File(filePath).path) + val linesOfCode = lines.size + val longestLineLength = if lines.isEmpty then 0 else lines.map(_.length).max + if longestLineLength >= LineLengthThreshold && linesOfCode <= 50 then + logger.debug( + s"'$filePath' seems to be a minified file (contains a line with length $longestLineLength)" + ) + true + else false + case _ => false + + private def isIgnoredByDefault(filePath: String): Boolean = + lazy val isIgnored = IgnoredFilesRegex.exists(_.matches(filePath)) + lazy val isIgnoredTest = IgnoredTestsRegex.exists(_.matches(filePath)) + lazy val isMinified = isMinifiedFile(filePath) + if isIgnored || isIgnoredTest || isMinified then + logger.debug(s"'$filePath' ignored by default") + true + else + false + + private def isTranspiledFile(filePath: String): Boolean = + val file = File(filePath) + // We ignore files iff: + // - they are *.js files and + // - they contain a //sourceMappingURL comment or have an associated source map file and + // - a file with the same name is located directly next to them + lazy val isJsFile = file.exists && file.extension.contains(".js") + lazy val hasSourceMapComment = + IOUtils.readLinesInFile(file.path).exists(_.contains("//sourceMappingURL")) + lazy val hasSourceMapFile = File(s"$filePath.map").exists + lazy val hasSourceMap = hasSourceMapComment || hasSourceMapFile + lazy val hasFileWithSameName = + file.siblings.exists(_.nameWithoutExtension(includeAll = false - - private def isMinifiedFile(filePath: String): Boolean = filePath match - case p if MinifiedPathRegex.matches(p) => true - case p if File(p).exists && p.endsWith(".js") => - val lines = IOUtils.readLinesInFile(File(filePath).path) - val linesOfCode = lines.size - val longestLineLength = if lines.isEmpty then 0 else lines.map(_.length).max - if longestLineLength >= LineLengthThreshold && linesOfCode <= 50 then - logger.debug( - s"'$filePath' seems to be a minified file (contains a line with length $longestLineLength)" - ) - true - else false - case _ => false - - private def isIgnoredByDefault(filePath: String): Boolean = - lazy val isIgnored = IgnoredFilesRegex.exists(_.matches(filePath)) - lazy val isIgnoredTest = IgnoredTestsRegex.exists(_.matches(filePath)) - lazy val isMinified = isMinifiedFile(filePath) - if isIgnored || isIgnoredTest || isMinified then - logger.debug(s"'$filePath' ignored by default") - true - else - false - - private def isTranspiledFile(filePath: String): Boolean = - val file = File(filePath) - // We ignore files iff: - // - they are *.js files and - // - they contain a //sourceMappingURL comment or have an associated source map file and - // - a file with the same name is located directly next to them - lazy val isJsFile = file.exists && file.extension.contains(".js") - lazy val hasSourceMapComment = - IOUtils.readLinesInFile(file.path).exists(_.contains("//sourceMappingURL")) - lazy val hasSourceMapFile = File(s"$filePath.map").exists - lazy val hasSourceMap = hasSourceMapComment || hasSourceMapFile - lazy val hasFileWithSameName = - file.siblings.exists(_.nameWithoutExtension(includeAll = - false - ) == file.nameWithoutExtension) - if isJsFile && hasSourceMap && hasFileWithSameName then - logger.debug( - s"'$filePath' ignored by default (seems to be the result of transpilation)" - ) + ) == file.nameWithoutExtension) + if isJsFile && hasSourceMap && hasFileWithSameName then + logger.debug( + s"'$filePath' ignored by default (seems to be the result of transpilation)" + ) + true + else + false + end isTranspiledFile + + private def filterFiles(files: List[String], out: File): List[String] = + files.filter { file => + file.stripSuffix(".json").replace(out.pathAsString, config.inputPath) match + // We are not interested in JS / TS type definition files at this stage. + // TODO: maybe we can enable that later on and use the type definitions there + // for enhancing the CPG with additional type information for functions + case filePath if TypeDefinitionFileExtensions.exists(filePath.endsWith) => false + case filePath if isIgnoredByUserConfig(filePath) => false + case filePath if isIgnoredByDefault(filePath) => false + case filePath if isTranspiledFile(filePath) => false + case _ => true + } + + /** Changes the file-extension by renaming this file; if file does not have an extension, it adds + * the extension. If file does not exist (or is a directory) no change is done and the current + * file is returned. + */ + private def changeExtensionTo(file: File, extension: String): File = + val newName = + s"${file.nameWithoutExtension(includeAll = false)}.${extension.stripPrefix(".")}" + if file.isRegularFile then file.renameTo(newName) + else if file.notExists then File(newName) + else file + + private def processEjsFiles(in: File, out: File, ejsFiles: List[String]): Try[Seq[String]] = + val tmpJsFiles = ejsFiles.map { ejsFilePath => + val ejsFile = File(ejsFilePath) + val maybeTranspiledFile = File(s"${ejsFilePath.stripSuffix(".ejs")}.js") + if isTranspiledFile(maybeTranspiledFile.pathAsString) then + maybeTranspiledFile + else + val sourceFileContent = IOUtils.readEntireFile(ejsFile.path) + val preprocessContent = new EjsPreprocessor().preprocess(sourceFileContent) + (out / in.relativize(ejsFile).toString).parent.createDirectoryIfNotExists(createParents = true - else - false - end isTranspiledFile - - private def filterFiles(files: List[String], out: File): List[String] = - files.filter { file => - file.stripSuffix(".json").replace(out.pathAsString, config.inputPath) match - // We are not interested in JS / TS type definition files at this stage. - // TODO: maybe we can enable that later on and use the type definitions there - // for enhancing the CPG with additional type information for functions - case filePath if TypeDefinitionFileExtensions.exists(filePath.endsWith) => false - case filePath if isIgnoredByUserConfig(filePath) => false - case filePath if isIgnoredByDefault(filePath) => false - case filePath if isTranspiledFile(filePath) => false - case _ => true - } - - /** Changes the file-extension by renaming this file; if file does not have an extension, it - * adds the extension. If file does not exist (or is a directory) no change is done and the - * current file is returned. - */ - private def changeExtensionTo(file: File, extension: String): File = - val newName = - s"${file.nameWithoutExtension(includeAll = false)}.${extension.stripPrefix(".")}" - if file.isRegularFile then file.renameTo(newName) - else if file.notExists then File(newName) - else file - - private def processEjsFiles(in: File, out: File, ejsFiles: List[String]): Try[Seq[String]] = - val tmpJsFiles = ejsFiles.map { ejsFilePath => - val ejsFile = File(ejsFilePath) - val maybeTranspiledFile = File(s"${ejsFilePath.stripSuffix(".ejs")}.js") - if isTranspiledFile(maybeTranspiledFile.pathAsString) then - maybeTranspiledFile - else - val sourceFileContent = IOUtils.readEntireFile(ejsFile.path) - val preprocessContent = new EjsPreprocessor().preprocess(sourceFileContent) - (out / in.relativize(ejsFile).toString).parent.createDirectoryIfNotExists( - createParents = true - ) - val newEjsFile = ejsFile.copyTo(out / in.relativize(ejsFile).toString) - val jsFile = changeExtensionTo(newEjsFile, ".js").writeText(preprocessContent) - newEjsFile.createFile().writeText(sourceFileContent) - jsFile - } - - val result = - ExternalCommand.run(s"$astGenCommand$executableArgs -t ts -o $out", out.toString()) - - val jsons = SourceFiles.determine(out.toString(), Set(".json")) - jsons.foreach { jsonPath => - val jsonFile = File(jsonPath) - val jsonContent = IOUtils.readEntireFile(jsonFile.path) - val json = ujson.read(jsonContent) - val fileName = json("fullName").str - val newFileName = fileName.patch(fileName.lastIndexOf(".js"), ".ejs", 3) - json("relativeName") = newFileName - json("fullName") = newFileName - jsonFile.writeText(json.toString()) - } - - tmpJsFiles.foreach(_.delete()) - result - end processEjsFiles - - private def ejsFiles(in: File, out: File): Try[Seq[String]] = - val files = SourceFiles.determine(in.pathAsString, Set(".ejs")) - if files.nonEmpty then processEjsFiles(in, out, files) - else Success(Seq.empty) - - private def vueFiles(in: File, out: File): Try[Seq[String]] = - val files = SourceFiles.determine(in.pathAsString, Set(".vue")) - if files.nonEmpty then - ExternalCommand.run(s"$astGenCommand$executableArgs -t vue -o $out", in.toString()) - else Success(Seq.empty) - - private def jsFiles(in: File, out: File): Try[Seq[String]] = - ExternalCommand.run(s"$astGenCommand$executableArgs -t ts -o $out", in.toString()) - - private def runAstGenNative(in: File, out: File): Try[Seq[String]] = - for - ejsResult <- ejsFiles(in, out) - vueResult <- vueFiles(in, out) - jsResult <- jsFiles(in, out) - yield jsResult ++ vueResult ++ ejsResult - - def execute(out: File): AstGenRunnerResult = - val in = File(config.inputPath) - logger.debug(s"Running astgen in '$in' ...") - runAstGenNative(in, out) match - case Success(result) => - val parsed = filterFiles(SourceFiles.determine(out.toString(), Set(".json")), out) - val skipped = skippedFiles(in, result.toList) - AstGenRunnerResult(parsed.map((in.toString(), _)), skipped.map((in.toString(), _))) - case Failure(f) => - logger.debug("\t- running astgen failed!", f) - AstGenRunnerResult() + ) + val newEjsFile = ejsFile.copyTo(out / in.relativize(ejsFile).toString) + val jsFile = changeExtensionTo(newEjsFile, ".js").writeText(preprocessContent) + newEjsFile.createFile().writeText(sourceFileContent) + jsFile + } + + val result = + ExternalCommand.run(s"$astGenCommand$executableArgs -t ts -o $out", out.toString()) + + val jsons = SourceFiles.determine(out.toString(), Set(".json")) + jsons.foreach { jsonPath => + val jsonFile = File(jsonPath) + val jsonContent = IOUtils.readEntireFile(jsonFile.path) + val json = ujson.read(jsonContent) + val fileName = json("fullName").str + val newFileName = fileName.patch(fileName.lastIndexOf(".js"), ".ejs", 3) + json("relativeName") = newFileName + json("fullName") = newFileName + jsonFile.writeText(json.toString()) + } + + tmpJsFiles.foreach(_.delete()) + result + end processEjsFiles + + private def ejsFiles(in: File, out: File): Try[Seq[String]] = + val files = SourceFiles.determine(in.pathAsString, Set(".ejs")) + if files.nonEmpty then processEjsFiles(in, out, files) + else Success(Seq.empty) + + private def vueFiles(in: File, out: File): Try[Seq[String]] = + val files = SourceFiles.determine(in.pathAsString, Set(".vue")) + if files.nonEmpty then + ExternalCommand.run(s"$astGenCommand$executableArgs -t vue -o $out", in.toString()) + else Success(Seq.empty) + + private def jsFiles(in: File, out: File): Try[Seq[String]] = + ExternalCommand.run(s"$astGenCommand$executableArgs -t ts -o $out", in.toString()) + + private def runAstGenNative(in: File, out: File): Try[Seq[String]] = + for + ejsResult <- ejsFiles(in, out) + vueResult <- vueFiles(in, out) + jsResult <- jsFiles(in, out) + yield jsResult ++ vueResult ++ ejsResult + + def execute(out: File): AstGenRunnerResult = + val in = File(config.inputPath) + logger.debug(s"Running astgen in '$in' ...") + runAstGenNative(in, out) match + case Success(result) => + val parsed = filterFiles(SourceFiles.determine(out.toString(), Set(".json")), out) + val skipped = skippedFiles(in, result.toList) + AstGenRunnerResult(parsed.map((in.toString(), _)), skipped.map((in.toString(), _))) + case Failure(f) => + logger.debug("\t- running astgen failed!", f) + AstGenRunnerResult() end AstGenRunner diff --git a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/utils/PackageJsonParser.scala b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/utils/PackageJsonParser.scala index c6d83559..4befda43 100644 --- a/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/utils/PackageJsonParser.scala +++ b/platform/frontends/jssrc2cpg/src/main/scala/io/appthreat/jssrc2cpg/utils/PackageJsonParser.scala @@ -13,81 +13,81 @@ import scala.util.Failure import scala.util.Success object PackageJsonParser: - private val logger = LoggerFactory.getLogger(PackageJsonParser.getClass) + private val logger = LoggerFactory.getLogger(PackageJsonParser.getClass) - val PackageJsonFilename = "package.json" - val PackageJsonLockFilename = "package-lock.json" + val PackageJsonFilename = "package.json" + val PackageJsonLockFilename = "package-lock.json" - private val ProjectDependencies = - Seq("dependencies", "devDependencies", "peerDependencies", "optionalDependencies") + private val ProjectDependencies = + Seq("dependencies", "devDependencies", "peerDependencies", "optionalDependencies") - private val cachedDependencies: TrieMap[Path, Map[String, String]] = TrieMap.empty + private val cachedDependencies: TrieMap[Path, Map[String, String]] = TrieMap.empty - def isValidProjectPackageJson(packageJsonPath: Path): Boolean = - if packageJsonPath.toString.endsWith(PackageJsonParser.PackageJsonFilename) then - val isNotEmpty = Try(IOUtils.readLinesInFile(packageJsonPath)) match - case Success(content) => - content.forall(l => StringUtils.isNotBlank(StringUtils.normalizeSpace(l))) - case Failure(_) => false - isNotEmpty && dependencies(packageJsonPath).nonEmpty - else - false + def isValidProjectPackageJson(packageJsonPath: Path): Boolean = + if packageJsonPath.toString.endsWith(PackageJsonParser.PackageJsonFilename) then + val isNotEmpty = Try(IOUtils.readLinesInFile(packageJsonPath)) match + case Success(content) => + content.forall(l => StringUtils.isNotBlank(StringUtils.normalizeSpace(l))) + case Failure(_) => false + isNotEmpty && dependencies(packageJsonPath).nonEmpty + else + false - def dependencies(packageJsonPath: Path): Map[String, String] = - cachedDependencies.getOrElseUpdate( - packageJsonPath, { - val depsPath = packageJsonPath - val lockDepsPath = packageJsonPath.resolveSibling(Paths.get(PackageJsonLockFilename)) + def dependencies(packageJsonPath: Path): Map[String, String] = + cachedDependencies.getOrElseUpdate( + packageJsonPath, { + val depsPath = packageJsonPath + val lockDepsPath = packageJsonPath.resolveSibling(Paths.get(PackageJsonLockFilename)) - val lockDeps = Try { - val content = IOUtils.readEntireFile(lockDepsPath) - val objectMapper = new ObjectMapper - val packageJson = objectMapper.readTree(content) + val lockDeps = Try { + val content = IOUtils.readEntireFile(lockDepsPath) + val objectMapper = new ObjectMapper + val packageJson = objectMapper.readTree(content) - var depToVersion = Map.empty[String, String] - val dependencyIt = Option(packageJson.get("dependencies")) - .map(_.fields().asScala) - .getOrElse(Iterator.empty) - dependencyIt.foreach { entry => - val depName = entry.getKey - val versionNode = entry.getValue.get("version") - if versionNode != null then - depToVersion = depToVersion.updated(depName, versionNode.asText()) - } - depToVersion - }.toOption + var depToVersion = Map.empty[String, String] + val dependencyIt = Option(packageJson.get("dependencies")) + .map(_.fields().asScala) + .getOrElse(Iterator.empty) + dependencyIt.foreach { entry => + val depName = entry.getKey + val versionNode = entry.getValue.get("version") + if versionNode != null then + depToVersion = depToVersion.updated(depName, versionNode.asText()) + } + depToVersion + }.toOption - // lazy val because we only evaluate this in case no package lock file is available. - lazy val deps = Try { - val content = IOUtils.readEntireFile(depsPath) - val objectMapper = new ObjectMapper - val packageJson = objectMapper.readTree(content) + // lazy val because we only evaluate this in case no package lock file is available. + lazy val deps = Try { + val content = IOUtils.readEntireFile(depsPath) + val objectMapper = new ObjectMapper + val packageJson = objectMapper.readTree(content) - var depToVersion = Map.empty[String, String] - ProjectDependencies - .foreach { dependency => - val dependencyIt = Option(packageJson.get(dependency)) - .map(_.fields().asScala) - .getOrElse(Iterator.empty) - dependencyIt.foreach { entry => - depToVersion = - depToVersion.updated(entry.getKey, entry.getValue.asText()) - } + var depToVersion = Map.empty[String, String] + ProjectDependencies + .foreach { dependency => + val dependencyIt = Option(packageJson.get(dependency)) + .map(_.fields().asScala) + .getOrElse(Iterator.empty) + dependencyIt.foreach { entry => + depToVersion = + depToVersion.updated(entry.getKey, entry.getValue.asText()) } - depToVersion - }.toOption + } + depToVersion + }.toOption - if lockDeps.isDefined && lockDeps.get.nonEmpty then - logger.debug(s"Loaded dependencies from '$lockDepsPath'.") - lockDeps.get - else if deps.isDefined && deps.get.nonEmpty then - logger.debug(s"Loaded dependencies from '$depsPath'.") - deps.get - else - logger.debug( - s"No project dependencies found in $PackageJsonFilename or $PackageJsonLockFilename at '${depsPath.getParent}'." - ) - Map.empty - } - ) + if lockDeps.isDefined && lockDeps.get.nonEmpty then + logger.debug(s"Loaded dependencies from '$lockDepsPath'.") + lockDeps.get + else if deps.isDefined && deps.get.nonEmpty then + logger.debug(s"Loaded dependencies from '$depsPath'.") + deps.get + else + logger.debug( + s"No project dependencies found in $PackageJsonFilename or $PackageJsonLockFilename at '${depsPath.getParent}'." + ) + Map.empty + } + ) end PackageJsonParser diff --git a/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/Main.scala b/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/Main.scala index 40f51f4f..aa080156 100644 --- a/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/Main.scala +++ b/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/Main.scala @@ -10,30 +10,30 @@ import scopt.OParser final case class Config(phpIni: Option[String] = None, phpParserBin: Option[String] = None) extends X2CpgConfig[Config] with TypeRecoveryParserConfig[Config]: - def withPhpIni(phpIni: String): Config = - copy(phpIni = Some(phpIni)).withInheritedFields(this) + def withPhpIni(phpIni: String): Config = + copy(phpIni = Some(phpIni)).withInheritedFields(this) - def withPhpParserBin(phpParserBin: String): Config = - copy(phpParserBin = Some(phpParserBin)).withInheritedFields(this) + def withPhpParserBin(phpParserBin: String): Config = + copy(phpParserBin = Some(phpParserBin)).withInheritedFields(this) object Frontend: - implicit val defaultConfig: Config = Config() + implicit val defaultConfig: Config = Config() - val cmdLineParser: OParser[Unit, Config] = - val builder = OParser.builder[Config] - import builder.* - OParser.sequence( - programName("php2atom"), - opt[String]("php-ini") - .action((x, c) => c.withPhpIni(x)) - .text("php.ini path used by php-parser. Defaults to php.ini shipped with Chen."), - opt[String]("php-parser-bin") - .action((x, c) => c.withPhpParserBin(x)) - .text("path to php-parser.phar binary."), - XTypeRecovery.parserOptions - ) + val cmdLineParser: OParser[Unit, Config] = + val builder = OParser.builder[Config] + import builder.* + OParser.sequence( + programName("php2atom"), + opt[String]("php-ini") + .action((x, c) => c.withPhpIni(x)) + .text("php.ini path used by php-parser. Defaults to php.ini shipped with Chen."), + opt[String]("php-parser-bin") + .action((x, c) => c.withPhpParserBin(x)) + .text("path to php-parser.phar binary."), + XTypeRecovery.parserOptions + ) object Main extends X2CpgMain(cmdLineParser, new Php2Atom()): - def run(config: Config, php2Cpg: Php2Atom): Unit = - php2Cpg.run(config) + def run(config: Config, php2Cpg: Php2Atom): Unit = + php2Cpg.run(config) diff --git a/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/Php2Atom.scala b/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/Php2Atom.scala index 758c6a7e..fc20530a 100644 --- a/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/Php2Atom.scala +++ b/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/Php2Atom.scala @@ -15,56 +15,56 @@ import scala.collection.mutable import scala.util.{Failure, Success, Try} class Php2Atom extends X2CpgFrontend[Config]: - private val logger = LoggerFactory.getLogger(this.getClass) + private val logger = LoggerFactory.getLogger(this.getClass) - private def isPhpVersionSupported: Boolean = - val result = ExternalCommand.run("php --version", ".") - result match - case Success(listString) => - true - case Failure(exception) => - logger.debug(s"Failed to run php --version: ${exception.getMessage}") - false + private def isPhpVersionSupported: Boolean = + val result = ExternalCommand.run("php --version", ".") + result match + case Success(listString) => + true + case Failure(exception) => + logger.debug(s"Failed to run php --version: ${exception.getMessage}") + false - override def createCpg(config: Config): Try[Cpg] = - val errorMessages = mutable.ListBuffer[String]() + override def createCpg(config: Config): Try[Cpg] = + val errorMessages = mutable.ListBuffer[String]() - val parser = PhpParser.getParser(config) + val parser = PhpParser.getParser(config) - if parser.isEmpty then - errorMessages.append("Could not initialize PhpParser") - if !isPhpVersionSupported then - errorMessages.append( - "PHP version not supported. Is PHP 7.1.0 or above installed and available on your path?" - ) + if parser.isEmpty then + errorMessages.append("Could not initialize PhpParser") + if !isPhpVersionSupported then + errorMessages.append( + "PHP version not supported. Is PHP 7.1.0 or above installed and available on your path?" + ) - if errorMessages.isEmpty then - withNewEmptyCpg(config.outputPath, config: Config) { (cpg, config) => - new MetaDataPass(cpg, Languages.PHP, config.inputPath).createAndApply() - new AstCreationPass(config, cpg, parser.get)( - config.schemaValidation - ).createAndApply() - new ConfigFileCreationPass(cpg).createAndApply() - new AstParentInfoPass(cpg).createAndApply() - new AnyTypePass(cpg).createAndApply() - TypeNodePass.withTypesFromCpg(cpg).createAndApply() - LocalCreationPass.allLocalCreationPasses(cpg).foreach(_.createAndApply()) - new ClosureRefPass(cpg).createAndApply() - } - else - val errorOutput = ( - "Skipping AST creation as php/php-parser could not be executed." :: - errorMessages.toList - ).mkString("\n- ") + if errorMessages.isEmpty then + withNewEmptyCpg(config.outputPath, config: Config) { (cpg, config) => + new MetaDataPass(cpg, Languages.PHP, config.inputPath).createAndApply() + new AstCreationPass(config, cpg, parser.get)( + config.schemaValidation + ).createAndApply() + new ConfigFileCreationPass(cpg).createAndApply() + new AstParentInfoPass(cpg).createAndApply() + new AnyTypePass(cpg).createAndApply() + TypeNodePass.withTypesFromCpg(cpg).createAndApply() + LocalCreationPass.allLocalCreationPasses(cpg).foreach(_.createAndApply()) + new ClosureRefPass(cpg).createAndApply() + } + else + val errorOutput = ( + "Skipping AST creation as php/php-parser could not be executed." :: + errorMessages.toList + ).mkString("\n- ") - logger.debug(errorOutput) + logger.debug(errorOutput) - Failure(new RuntimeException("php not found or version not supported")) - end if - end createCpg + Failure(new RuntimeException("php not found or version not supported")) + end if + end createCpg end Php2Atom object Php2Atom: - def postProcessingPasses(cpg: Cpg, config: Option[Config] = None): List[CpgPassBase] = - List(new PhpSetKnownTypesPass(cpg)) + def postProcessingPasses(cpg: Cpg, config: Option[Config] = None): List[CpgPassBase] = + List(new PhpSetKnownTypesPass(cpg)) diff --git a/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/astcreation/AstCreator.scala b/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/astcreation/AstCreator.scala index c3093650..79a66f9b 100644 --- a/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/astcreation/AstCreator.scala +++ b/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/astcreation/AstCreator.scala @@ -21,2027 +21,2028 @@ class AstCreator(filename: String, phpAst: PhpFile)(implicit withSchemaValidatio extends AstCreatorBase(filename) with AstNodeBuilder[PhpNode, AstCreator]: - private val logger = LoggerFactory.getLogger(AstCreator.getClass) - private val scope = new Scope()(() => nextClosureName()) - private val tmpKeyPool = new IntervalKeyPool(first = 0, last = Long.MaxValue) - private val globalNamespace = globalNamespaceBlock() - - private def getNewTmpName(prefix: String = "tmp"): String = - s"$prefix${tmpKeyPool.next.toString}" - - override def createAst(): BatchedUpdate.DiffGraphBuilder = - val ast = astForPhpFile(phpAst) - storeInDiffGraph(ast, diffGraph) - diffGraph - - private def flattenGlobalNamespaceStmt(stmt: PhpStmt): List[PhpStmt] = - stmt match - case namespace: PhpNamespaceStmt if namespace.name.isEmpty => - namespace.stmts - - case _ => stmt :: Nil - - private def globalTypeDeclNode(file: PhpFile, globalNamespace: NewNamespaceBlock): NewTypeDecl = - typeDeclNode( - file, - globalNamespace.name, - globalNamespace.fullName, - filename, - globalNamespace.code, - NodeTypes.NAMESPACE_BLOCK, - globalNamespace.fullName - ) - - private def globalMethodDeclStmt(file: PhpFile, bodyStmts: List[PhpStmt]): PhpMethodDecl = - val modifiersList = List(ModifierTypes.VIRTUAL, ModifierTypes.PUBLIC, ModifierTypes.STATIC) - PhpMethodDecl( - name = PhpNameExpr(NamespaceTraversal.globalNamespaceName, file.attributes), - params = Nil, - modifiers = modifiersList, - returnType = None, - stmts = bodyStmts, - returnByRef = false, - namespacedName = None, - isClassMethod = false, - attributes = file.attributes - ) - - private def astForPhpFile(file: PhpFile): Ast = - scope.pushNewScope(globalNamespace) + private val logger = LoggerFactory.getLogger(AstCreator.getClass) + private val scope = new Scope()(() => nextClosureName()) + private val tmpKeyPool = new IntervalKeyPool(first = 0, last = Long.MaxValue) + private val globalNamespace = globalNamespaceBlock() + + private def getNewTmpName(prefix: String = "tmp"): String = + s"$prefix${tmpKeyPool.next.toString}" + + override def createAst(): BatchedUpdate.DiffGraphBuilder = + val ast = astForPhpFile(phpAst) + storeInDiffGraph(ast, diffGraph) + diffGraph + + private def flattenGlobalNamespaceStmt(stmt: PhpStmt): List[PhpStmt] = + stmt match + case namespace: PhpNamespaceStmt if namespace.name.isEmpty => + namespace.stmts + + case _ => stmt :: Nil + + private def globalTypeDeclNode(file: PhpFile, globalNamespace: NewNamespaceBlock): NewTypeDecl = + typeDeclNode( + file, + globalNamespace.name, + globalNamespace.fullName, + filename, + globalNamespace.code, + NodeTypes.NAMESPACE_BLOCK, + globalNamespace.fullName + ) + + private def globalMethodDeclStmt(file: PhpFile, bodyStmts: List[PhpStmt]): PhpMethodDecl = + val modifiersList = List(ModifierTypes.VIRTUAL, ModifierTypes.PUBLIC, ModifierTypes.STATIC) + PhpMethodDecl( + name = PhpNameExpr(NamespaceTraversal.globalNamespaceName, file.attributes), + params = Nil, + modifiers = modifiersList, + returnType = None, + stmts = bodyStmts, + returnByRef = false, + namespacedName = None, + isClassMethod = false, + attributes = file.attributes + ) - val (globalDeclStmts, globalMethodStmts) = - file.children.flatMap(flattenGlobalNamespaceStmt).partition( - _.isInstanceOf[PhpConstStmt] - ) + private def astForPhpFile(file: PhpFile): Ast = + scope.pushNewScope(globalNamespace) - val globalMethodStmt = globalMethodDeclStmt(file, globalMethodStmts) - - val globalTypeDeclStmt = PhpClassLikeStmt( - name = Some(PhpNameExpr(globalNamespace.name, file.attributes)), - modifiers = Nil, - extendsNames = Nil, - implementedInterfaces = Nil, - stmts = globalDeclStmts.appended(globalMethodStmt), - classLikeType = ClassLikeTypes.Class, - scalarType = None, - hasConstructor = false, - attributes = file.attributes + val (globalDeclStmts, globalMethodStmts) = + file.children.flatMap(flattenGlobalNamespaceStmt).partition( + _.isInstanceOf[PhpConstStmt] ) - val globalTypeDeclAst = astForClassLikeStmt(globalTypeDeclStmt) - - scope.popScope() // globalNamespace - - Ast(globalNamespace).withChild(globalTypeDeclAst) - end astForPhpFile - - private def astsForStmt(stmt: PhpStmt): List[Ast] = - stmt match - case echoStmt: PhpEchoStmt => astForEchoStmt(echoStmt) :: Nil - case methodDecl: PhpMethodDecl => astForMethodDecl(methodDecl) :: Nil - case expr: PhpExpr => astForExpr(expr) :: Nil - case breakStmt: PhpBreakStmt => astForBreakStmt(breakStmt) :: Nil - case contStmt: PhpContinueStmt => astForContinueStmt(contStmt) :: Nil - case whileStmt: PhpWhileStmt => astForWhileStmt(whileStmt) :: Nil - case doStmt: PhpDoStmt => astForDoStmt(doStmt) :: Nil - case forStmt: PhpForStmt => astForForStmt(forStmt) :: Nil - case ifStmt: PhpIfStmt => astForIfStmt(ifStmt) :: Nil - case switchStmt: PhpSwitchStmt => astForSwitchStmt(switchStmt) :: Nil - case tryStmt: PhpTryStmt => astForTryStmt(tryStmt) :: Nil - case returnStmt: PhpReturnStmt => astForReturnStmt(returnStmt) :: Nil - case classLikeStmt: PhpClassLikeStmt => astForClassLikeStmt(classLikeStmt) :: Nil - case gotoStmt: PhpGotoStmt => astForGotoStmt(gotoStmt) :: Nil - case labelStmt: PhpLabelStmt => astForLabelStmt(labelStmt) :: Nil - case namespace: PhpNamespaceStmt => astForNamespaceStmt(namespace) :: Nil - case declareStmt: PhpDeclareStmt => astForDeclareStmt(declareStmt) :: Nil - case _: NopStmt => Nil // TODO This'll need to be updated when comments are added. - case haltStmt: PhpHaltCompilerStmt => astForHaltCompilerStmt(haltStmt) :: Nil - case unsetStmt: PhpUnsetStmt => astForUnsetStmt(unsetStmt) :: Nil - case globalStmt: PhpGlobalStmt => astForGlobalStmt(globalStmt) :: Nil - case useStmt: PhpUseStmt => astForUseStmt(useStmt) :: Nil - case groupUseStmt: PhpGroupUseStmt => astForGroupUseStmt(groupUseStmt) :: Nil - case foreachStmt: PhpForeachStmt => astForForeachStmt(foreachStmt) :: Nil - case traitUseStmt: PhpTraitUseStmt => astforTraitUseStmt(traitUseStmt) :: Nil - case enumCase: PhpEnumCaseStmt => astForEnumCase(enumCase) :: Nil - case staticStmt: PhpStaticStmt => astsForStaticStmt(staticStmt) - case unhandled => - logger.debug(s"Unhandled stmt $unhandled in $filename") - ??? - - private def astForEchoStmt(echoStmt: PhpEchoStmt): Ast = - val args = echoStmt.exprs.map(astForExpr) - val code = s"echo ${args.map(_.rootCodeOrEmpty).mkString(",")}" - val callNode = - newOperatorCallNode("echo", code, line = line(echoStmt), column = column(echoStmt)) - callAst(callNode, args) - - private def thisParamAstForMethod(originNode: PhpNode): Ast = - val typeFullName = scope.getEnclosingTypeDeclTypeFullName.getOrElse(TypeConstants.Any) - - val thisNode = parameterInNode( - originNode, - name = NameConstants.This, - code = NameConstants.This, - index = 0, - isVariadic = false, - evaluationStrategy = EvaluationStrategies.BY_SHARING, - typeFullName = typeFullName - ).dynamicTypeHintFullName(typeFullName :: Nil) - // TODO Add dynamicTypeHintFullName to parameterInNode param list - - scope.addToScope(NameConstants.This, thisNode) - - Ast(thisNode) - - private def thisIdentifier(lineNumber: Option[Integer]): NewIdentifier = - val typ = scope.getEnclosingTypeDeclTypeName - newIdentifierNode(NameConstants.This, typ.getOrElse("ANY"), typ.toList, lineNumber) - .code(s"$$${NameConstants.This}") - - private def setParamIndices(asts: Seq[Ast]): Seq[Ast] = - asts.map(_.root).zipWithIndex.foreach { - case (Some(root: NewMethodParameterIn), idx) => - root.index(idx + 1) - - case (root, _) => - logger.debug(s"Trying to set index for unsupported node $root") - } - - asts + val globalMethodStmt = globalMethodDeclStmt(file, globalMethodStmts) + + val globalTypeDeclStmt = PhpClassLikeStmt( + name = Some(PhpNameExpr(globalNamespace.name, file.attributes)), + modifiers = Nil, + extendsNames = Nil, + implementedInterfaces = Nil, + stmts = globalDeclStmts.appended(globalMethodStmt), + classLikeType = ClassLikeTypes.Class, + scalarType = None, + hasConstructor = false, + attributes = file.attributes + ) - private def composeMethodFullName(methodName: String, isStatic: Boolean): String = - if methodName == NamespaceTraversal.globalNamespaceName then - globalNamespace.fullName + val globalTypeDeclAst = astForClassLikeStmt(globalTypeDeclStmt) + + scope.popScope() // globalNamespace + + Ast(globalNamespace).withChild(globalTypeDeclAst) + end astForPhpFile + + private def astsForStmt(stmt: PhpStmt): List[Ast] = + stmt match + case echoStmt: PhpEchoStmt => astForEchoStmt(echoStmt) :: Nil + case methodDecl: PhpMethodDecl => astForMethodDecl(methodDecl) :: Nil + case expr: PhpExpr => astForExpr(expr) :: Nil + case breakStmt: PhpBreakStmt => astForBreakStmt(breakStmt) :: Nil + case contStmt: PhpContinueStmt => astForContinueStmt(contStmt) :: Nil + case whileStmt: PhpWhileStmt => astForWhileStmt(whileStmt) :: Nil + case doStmt: PhpDoStmt => astForDoStmt(doStmt) :: Nil + case forStmt: PhpForStmt => astForForStmt(forStmt) :: Nil + case ifStmt: PhpIfStmt => astForIfStmt(ifStmt) :: Nil + case switchStmt: PhpSwitchStmt => astForSwitchStmt(switchStmt) :: Nil + case tryStmt: PhpTryStmt => astForTryStmt(tryStmt) :: Nil + case returnStmt: PhpReturnStmt => astForReturnStmt(returnStmt) :: Nil + case classLikeStmt: PhpClassLikeStmt => astForClassLikeStmt(classLikeStmt) :: Nil + case gotoStmt: PhpGotoStmt => astForGotoStmt(gotoStmt) :: Nil + case labelStmt: PhpLabelStmt => astForLabelStmt(labelStmt) :: Nil + case namespace: PhpNamespaceStmt => astForNamespaceStmt(namespace) :: Nil + case declareStmt: PhpDeclareStmt => astForDeclareStmt(declareStmt) :: Nil + case _: NopStmt => Nil // TODO This'll need to be updated when comments are added. + case haltStmt: PhpHaltCompilerStmt => astForHaltCompilerStmt(haltStmt) :: Nil + case unsetStmt: PhpUnsetStmt => astForUnsetStmt(unsetStmt) :: Nil + case globalStmt: PhpGlobalStmt => astForGlobalStmt(globalStmt) :: Nil + case useStmt: PhpUseStmt => astForUseStmt(useStmt) :: Nil + case groupUseStmt: PhpGroupUseStmt => astForGroupUseStmt(groupUseStmt) :: Nil + case foreachStmt: PhpForeachStmt => astForForeachStmt(foreachStmt) :: Nil + case traitUseStmt: PhpTraitUseStmt => astforTraitUseStmt(traitUseStmt) :: Nil + case enumCase: PhpEnumCaseStmt => astForEnumCase(enumCase) :: Nil + case staticStmt: PhpStaticStmt => astsForStaticStmt(staticStmt) + case unhandled => + logger.debug(s"Unhandled stmt $unhandled in $filename") + ??? + + private def astForEchoStmt(echoStmt: PhpEchoStmt): Ast = + val args = echoStmt.exprs.map(astForExpr) + val code = s"echo ${args.map(_.rootCodeOrEmpty).mkString(",")}" + val callNode = + newOperatorCallNode("echo", code, line = line(echoStmt), column = column(echoStmt)) + callAst(callNode, args) + + private def thisParamAstForMethod(originNode: PhpNode): Ast = + val typeFullName = scope.getEnclosingTypeDeclTypeFullName.getOrElse(TypeConstants.Any) + + val thisNode = parameterInNode( + originNode, + name = NameConstants.This, + code = NameConstants.This, + index = 0, + isVariadic = false, + evaluationStrategy = EvaluationStrategies.BY_SHARING, + typeFullName = typeFullName + ).dynamicTypeHintFullName(typeFullName :: Nil) + // TODO Add dynamicTypeHintFullName to parameterInNode param list + + scope.addToScope(NameConstants.This, thisNode) + + Ast(thisNode) + + private def thisIdentifier(lineNumber: Option[Integer]): NewIdentifier = + val typ = scope.getEnclosingTypeDeclTypeName + newIdentifierNode(NameConstants.This, typ.getOrElse("ANY"), typ.toList, lineNumber) + .code(s"$$${NameConstants.This}") + + private def setParamIndices(asts: Seq[Ast]): Seq[Ast] = + asts.map(_.root).zipWithIndex.foreach { + case (Some(root: NewMethodParameterIn), idx) => + root.index(idx + 1) + + case (root, _) => + logger.debug(s"Trying to set index for unsupported node $root") + } + + asts + + private def composeMethodFullName(methodName: String, isStatic: Boolean): String = + if methodName == NamespaceTraversal.globalNamespaceName then + globalNamespace.fullName + else + val className = getTypeDeclPrefix + val methodDelimiter = + if isStatic then StaticMethodDelimiter else InstanceMethodDelimiter + + val nameWithClass = List(className, Some(methodName)).flatten.mkString(methodDelimiter) + + prependNamespacePrefix(nameWithClass) + + private def astForMethodDecl( + decl: PhpMethodDecl, + bodyPrefixAsts: List[Ast] = Nil, + fullNameOverride: Option[String] = None, + isConstructor: Boolean = false + ): Ast = + val isStatic = decl.modifiers.contains(ModifierTypes.STATIC) + val thisParam = if decl.isClassMethod && !isStatic then + Option(thisParamAstForMethod(decl)) + else + None + + val methodName = decl.name.name + val fullName = fullNameOverride.getOrElse(composeMethodFullName(methodName, isStatic)) + + val signature = s"$UnresolvedSignature(${decl.params.size})" + + val parameters = thisParam.toList ++ decl.params.zipWithIndex.map { case (param, idx) => + astForParam(param, idx + 1) + } + + val constructorModifier = Option.when(isConstructor)(ModifierTypes.CONSTRUCTOR) + val defaultAccessModifier = + Option.unless(containsAccessModifier(decl.modifiers))(ModifierTypes.PUBLIC) + + val allModifiers = constructorModifier ++: defaultAccessModifier ++: decl.modifiers + val modifiers = allModifiers.map(newModifierNode) + val excludedModifiers = Set("MODULE", "LAMBDA") + val modifierString = decl.modifiers.filterNot(excludedModifiers.contains) match + case Nil => "" + case mods => s"${mods.mkString(" ")} " + val methodCode = + s"${modifierString}function $methodName(${parameters.map(_.rootCodeOrEmpty).mkString(",")})" + + val method = methodNode(decl, methodName, methodCode, fullName, Some(signature), filename) + + scope.pushNewScope(method) + + val returnType = decl.returnType.map(_.name).getOrElse(TypeConstants.Any) + + val methodBodyStmts = bodyPrefixAsts ++ decl.stmts.flatMap(astsForStmt) + val methodReturn = newMethodReturnNode(returnType, line = line(decl), column = column(decl)) + + val methodBody = blockAst(blockNode(decl), methodBodyStmts) + + scope.popScope() + methodAstWithAnnotations(method, parameters, methodBody, methodReturn, modifiers) + end astForMethodDecl + + private def stmtBodyBlockAst(stmt: PhpStmtWithBody): Ast = + val bodyBlock = blockNode(stmt) + val bodyStmtAsts = stmt.stmts.flatMap(astsForStmt) + Ast(bodyBlock).withChildren(bodyStmtAsts) + + private def astForParam(param: PhpParam, index: Int): Ast = + val evaluationStrategy = + if param.byRef then + EvaluationStrategies.BY_REFERENCE else - val className = getTypeDeclPrefix - val methodDelimiter = - if isStatic then StaticMethodDelimiter else InstanceMethodDelimiter - - val nameWithClass = List(className, Some(methodName)).flatten.mkString(methodDelimiter) - - prependNamespacePrefix(nameWithClass) - - private def astForMethodDecl( - decl: PhpMethodDecl, - bodyPrefixAsts: List[Ast] = Nil, - fullNameOverride: Option[String] = None, - isConstructor: Boolean = false - ): Ast = - val isStatic = decl.modifiers.contains(ModifierTypes.STATIC) - val thisParam = if decl.isClassMethod && !isStatic then - Option(thisParamAstForMethod(decl)) - else - None - - val methodName = decl.name.name - val fullName = fullNameOverride.getOrElse(composeMethodFullName(methodName, isStatic)) + EvaluationStrategies.BY_VALUE + + val typeFullName = param.paramType.map(_.name).getOrElse(TypeConstants.Any) + + val byRefCodePrefix = if param.byRef then "&" else "" + val code = s"$byRefCodePrefix$$${param.name}" + val paramNode = parameterInNode( + param, + param.name, + code, + index, + param.isVariadic, + evaluationStrategy, + typeFullName + ) - val signature = s"$UnresolvedSignature(${decl.params.size})" + scope.addToScope(param.name, paramNode) + + Ast(paramNode) + end astForParam + + private def astForExpr(expr: PhpExpr): Ast = + expr match + case funcCallExpr: PhpCallExpr => astForCall(funcCallExpr) + case variableExpr: PhpVariable => astForVariableExpr(variableExpr) + case nameExpr: PhpNameExpr => astForNameExpr(nameExpr) + case assignExpr: PhpAssignment => astForAssignment(assignExpr) + case scalarExpr: PhpScalar => astForScalar(scalarExpr) + case binaryOp: PhpBinaryOp => astForBinOp(binaryOp) + case unaryOp: PhpUnaryOp => astForUnaryOp(unaryOp) + case castExpr: PhpCast => astForCastExpr(castExpr) + case isSetExpr: PhpIsset => astForIsSetExpr(isSetExpr) + case printExpr: PhpPrint => astForPrintExpr(printExpr) + case ternaryOp: PhpTernaryOp => astForTernaryOp(ternaryOp) + case throwExpr: PhpThrowExpr => astForThrow(throwExpr) + case cloneExpr: PhpCloneExpr => astForClone(cloneExpr) + case emptyExpr: PhpEmptyExpr => astForEmpty(emptyExpr) + case evalExpr: PhpEvalExpr => astForEval(evalExpr) + case exitExpr: PhpExitExpr => astForExit(exitExpr) + case arrayExpr: PhpArrayExpr => astForArrayExpr(arrayExpr) + case listExpr: PhpListExpr => astForListExpr(listExpr) + case newExpr: PhpNewExpr => astForNewExpr(newExpr) + case matchExpr: PhpMatchExpr => astForMatchExpr(matchExpr) + case yieldExpr: PhpYieldExpr => astForYieldExpr(yieldExpr) + case closure: PhpClosureExpr => astForClosureExpr(closure) + case yieldFromExpr: PhpYieldFromExpr => astForYieldFromExpr(yieldFromExpr) + case classConstFetchExpr: PhpClassConstFetchExpr => + astForClassConstFetchExpr(classConstFetchExpr) + case constFetchExpr: PhpConstFetchExpr => astForConstFetchExpr(constFetchExpr) + case arrayDimFetchExpr: PhpArrayDimFetchExpr => + astForArrayDimFetchExpr(arrayDimFetchExpr) + case errorSuppressExpr: PhpErrorSuppressExpr => + astForErrorSuppressExpr(errorSuppressExpr) + case instanceOfExpr: PhpInstanceOfExpr => astForInstanceOfExpr(instanceOfExpr) + case propertyFetchExpr: PhpPropertyFetchExpr => + astForPropertyFetchExpr(propertyFetchExpr) + case includeExpr: PhpIncludeExpr => astForIncludeExpr(includeExpr) + case shellExecExpr: PhpShellExecExpr => astForShellExecExpr(shellExecExpr) + case null => + logger.debug("expr was null") + ??? + case other => throw new NotImplementedError( + s"unexpected expression '$other' of type ${other.getClass}" + ) - val parameters = thisParam.toList ++ decl.params.zipWithIndex.map { case (param, idx) => - astForParam(param, idx + 1) - } + private def intToLiteralAst(num: Int): Ast = + Ast(NewLiteral().code(num.toString).typeFullName(TypeConstants.Int)) - val constructorModifier = Option.when(isConstructor)(ModifierTypes.CONSTRUCTOR) - val defaultAccessModifier = - Option.unless(containsAccessModifier(decl.modifiers))(ModifierTypes.PUBLIC) + private def astForBreakStmt(breakStmt: PhpBreakStmt): Ast = + val code = breakStmt.num.map(num => s"break($num)").getOrElse("break") + val breakNode = controlStructureNode(breakStmt, ControlStructureTypes.BREAK, code) - val allModifiers = constructorModifier ++: defaultAccessModifier ++: decl.modifiers - val modifiers = allModifiers.map(newModifierNode) - val excludedModifiers = Set("MODULE", "LAMBDA") - val modifierString = decl.modifiers.filterNot(excludedModifiers.contains) match - case Nil => "" - case mods => s"${mods.mkString(" ")} " - val methodCode = - s"${modifierString}function $methodName(${parameters.map(_.rootCodeOrEmpty).mkString(",")})" + val argument = breakStmt.num.map(intToLiteralAst) - val method = methodNode(decl, methodName, methodCode, fullName, Some(signature), filename) + controlStructureAst(breakNode, None, argument.toList) - scope.pushNewScope(method) + private def astForContinueStmt(continueStmt: PhpContinueStmt): Ast = + val code = continueStmt.num.map(num => s"continue($num)").getOrElse("continue") + val continueNode = controlStructureNode(continueStmt, ControlStructureTypes.CONTINUE, code) - val returnType = decl.returnType.map(_.name).getOrElse(TypeConstants.Any) + val argument = continueStmt.num.map(intToLiteralAst) - val methodBodyStmts = bodyPrefixAsts ++ decl.stmts.flatMap(astsForStmt) - val methodReturn = newMethodReturnNode(returnType, line = line(decl), column = column(decl)) + controlStructureAst(continueNode, None, argument.toList) - val methodBody = blockAst(blockNode(decl), methodBodyStmts) + private def astForWhileStmt(whileStmt: PhpWhileStmt): Ast = + val condition = astForExpr(whileStmt.cond) + val lineNumber = line(whileStmt) + val code = s"while (${condition.rootCodeOrEmpty})" + val body = stmtBodyBlockAst(whileStmt) - scope.popScope() - methodAstWithAnnotations(method, parameters, methodBody, methodReturn, modifiers) - end astForMethodDecl + whileAst( + Option(condition), + List(body), + Option(code), + lineNumber, + columnNumber = column(whileStmt) + ) - private def stmtBodyBlockAst(stmt: PhpStmtWithBody): Ast = - val bodyBlock = blockNode(stmt) - val bodyStmtAsts = stmt.stmts.flatMap(astsForStmt) - Ast(bodyBlock).withChildren(bodyStmtAsts) + private def astForDoStmt(doStmt: PhpDoStmt): Ast = + val condition = astForExpr(doStmt.cond) + val lineNumber = line(doStmt) + val code = s"do {...} while (${condition.rootCodeOrEmpty})" + val body = stmtBodyBlockAst(doStmt) + + doWhileAst( + Option(condition), + List(body), + Option(code), + lineNumber, + columnNumber = column(doStmt) + ) - private def astForParam(param: PhpParam, index: Int): Ast = - val evaluationStrategy = - if param.byRef then - EvaluationStrategies.BY_REFERENCE - else - EvaluationStrategies.BY_VALUE + private def astForForStmt(stmt: PhpForStmt): Ast = + val lineNumber = line(stmt) - val typeFullName = param.paramType.map(_.name).getOrElse(TypeConstants.Any) + val initAsts = stmt.inits.map(astForExpr) + val conditionAsts = stmt.conditions.map(astForExpr) + val loopExprAsts = stmt.loopExprs.map(astForExpr) - val byRefCodePrefix = if param.byRef then "&" else "" - val code = s"$byRefCodePrefix$$${param.name}" - val paramNode = parameterInNode( - param, - param.name, - code, - index, - param.isVariadic, - evaluationStrategy, - typeFullName - ) + val bodyAst = stmtBodyBlockAst(stmt) - scope.addToScope(param.name, paramNode) - - Ast(paramNode) - end astForParam - - private def astForExpr(expr: PhpExpr): Ast = - expr match - case funcCallExpr: PhpCallExpr => astForCall(funcCallExpr) - case variableExpr: PhpVariable => astForVariableExpr(variableExpr) - case nameExpr: PhpNameExpr => astForNameExpr(nameExpr) - case assignExpr: PhpAssignment => astForAssignment(assignExpr) - case scalarExpr: PhpScalar => astForScalar(scalarExpr) - case binaryOp: PhpBinaryOp => astForBinOp(binaryOp) - case unaryOp: PhpUnaryOp => astForUnaryOp(unaryOp) - case castExpr: PhpCast => astForCastExpr(castExpr) - case isSetExpr: PhpIsset => astForIsSetExpr(isSetExpr) - case printExpr: PhpPrint => astForPrintExpr(printExpr) - case ternaryOp: PhpTernaryOp => astForTernaryOp(ternaryOp) - case throwExpr: PhpThrowExpr => astForThrow(throwExpr) - case cloneExpr: PhpCloneExpr => astForClone(cloneExpr) - case emptyExpr: PhpEmptyExpr => astForEmpty(emptyExpr) - case evalExpr: PhpEvalExpr => astForEval(evalExpr) - case exitExpr: PhpExitExpr => astForExit(exitExpr) - case arrayExpr: PhpArrayExpr => astForArrayExpr(arrayExpr) - case listExpr: PhpListExpr => astForListExpr(listExpr) - case newExpr: PhpNewExpr => astForNewExpr(newExpr) - case matchExpr: PhpMatchExpr => astForMatchExpr(matchExpr) - case yieldExpr: PhpYieldExpr => astForYieldExpr(yieldExpr) - case closure: PhpClosureExpr => astForClosureExpr(closure) - case yieldFromExpr: PhpYieldFromExpr => astForYieldFromExpr(yieldFromExpr) - case classConstFetchExpr: PhpClassConstFetchExpr => - astForClassConstFetchExpr(classConstFetchExpr) - case constFetchExpr: PhpConstFetchExpr => astForConstFetchExpr(constFetchExpr) - case arrayDimFetchExpr: PhpArrayDimFetchExpr => - astForArrayDimFetchExpr(arrayDimFetchExpr) - case errorSuppressExpr: PhpErrorSuppressExpr => - astForErrorSuppressExpr(errorSuppressExpr) - case instanceOfExpr: PhpInstanceOfExpr => astForInstanceOfExpr(instanceOfExpr) - case propertyFetchExpr: PhpPropertyFetchExpr => - astForPropertyFetchExpr(propertyFetchExpr) - case includeExpr: PhpIncludeExpr => astForIncludeExpr(includeExpr) - case shellExecExpr: PhpShellExecExpr => astForShellExecExpr(shellExecExpr) - case null => - logger.debug("expr was null") - ??? - case other => throw new NotImplementedError( - s"unexpected expression '$other' of type ${other.getClass}" - ) - - private def intToLiteralAst(num: Int): Ast = - Ast(NewLiteral().code(num.toString).typeFullName(TypeConstants.Int)) + val initCode = initAsts.map(_.rootCodeOrEmpty).mkString(",") + val conditionCode = conditionAsts.map(_.rootCodeOrEmpty).mkString(",") + val loopExprCode = loopExprAsts.map(_.rootCodeOrEmpty).mkString(",") + val forCode = s"for ($initCode;$conditionCode;$loopExprCode)" - private def astForBreakStmt(breakStmt: PhpBreakStmt): Ast = - val code = breakStmt.num.map(num => s"break($num)").getOrElse("break") - val breakNode = controlStructureNode(breakStmt, ControlStructureTypes.BREAK, code) + val forNode = controlStructureNode(stmt, ControlStructureTypes.FOR, forCode) + forAst(forNode, Nil, initAsts, conditionAsts, loopExprAsts, bodyAst) - val argument = breakStmt.num.map(intToLiteralAst) + private def astForIfStmt(ifStmt: PhpIfStmt): Ast = + val condition = astForExpr(ifStmt.cond) - controlStructureAst(breakNode, None, argument.toList) + val thenAst = stmtBodyBlockAst(ifStmt) - private def astForContinueStmt(continueStmt: PhpContinueStmt): Ast = - val code = continueStmt.num.map(num => s"continue($num)").getOrElse("continue") - val continueNode = controlStructureNode(continueStmt, ControlStructureTypes.CONTINUE, code) + val elseAst = ifStmt.elseIfs match + case Nil => ifStmt.elseStmt.map(els => stmtBodyBlockAst(els)).toList - val argument = continueStmt.num.map(intToLiteralAst) + case elseIf :: rest => + val newIfStmt = + PhpIfStmt(elseIf.cond, elseIf.stmts, rest, ifStmt.elseStmt, elseIf.attributes) + val wrappingBlock = blockNode(elseIf) + val wrappedAst = Ast(wrappingBlock).withChild(astForIfStmt(newIfStmt)) :: Nil + wrappedAst - controlStructureAst(continueNode, None, argument.toList) + val conditionCode = condition.rootCodeOrEmpty + val ifNode = controlStructureNode(ifStmt, ControlStructureTypes.IF, s"if ($conditionCode)") - private def astForWhileStmt(whileStmt: PhpWhileStmt): Ast = - val condition = astForExpr(whileStmt.cond) - val lineNumber = line(whileStmt) - val code = s"while (${condition.rootCodeOrEmpty})" - val body = stmtBodyBlockAst(whileStmt) + controlStructureAst(ifNode, Option(condition), thenAst :: elseAst) + end astForIfStmt - whileAst( - Option(condition), - List(body), - Option(code), - lineNumber, - columnNumber = column(whileStmt) - ) + private def astForSwitchStmt(stmt: PhpSwitchStmt): Ast = + val conditionAst = astForExpr(stmt.condition) - private def astForDoStmt(doStmt: PhpDoStmt): Ast = - val condition = astForExpr(doStmt.cond) - val lineNumber = line(doStmt) - val code = s"do {...} while (${condition.rootCodeOrEmpty})" - val body = stmtBodyBlockAst(doStmt) - - doWhileAst( - Option(condition), - List(body), - Option(code), - lineNumber, - columnNumber = column(doStmt) + val switchNode = + controlStructureNode( + stmt, + ControlStructureTypes.SWITCH, + s"switch (${conditionAst.rootCodeOrEmpty})" ) - private def astForForStmt(stmt: PhpForStmt): Ast = - val lineNumber = line(stmt) - - val initAsts = stmt.inits.map(astForExpr) - val conditionAsts = stmt.conditions.map(astForExpr) - val loopExprAsts = stmt.loopExprs.map(astForExpr) - - val bodyAst = stmtBodyBlockAst(stmt) + val switchBodyBlock = blockNode(stmt) + val entryAsts = stmt.cases.flatMap(astsForSwitchCase) + val switchBody = Ast(switchBodyBlock).withChildren(entryAsts) - val initCode = initAsts.map(_.rootCodeOrEmpty).mkString(",") - val conditionCode = conditionAsts.map(_.rootCodeOrEmpty).mkString(",") - val loopExprCode = loopExprAsts.map(_.rootCodeOrEmpty).mkString(",") - val forCode = s"for ($initCode;$conditionCode;$loopExprCode)" + controlStructureAst(switchNode, Option(conditionAst), switchBody :: Nil) - val forNode = controlStructureNode(stmt, ControlStructureTypes.FOR, forCode) - forAst(forNode, Nil, initAsts, conditionAsts, loopExprAsts, bodyAst) + private def astForTryStmt(stmt: PhpTryStmt): Ast = + val tryBody = stmtBodyBlockAst(stmt) + val catches = stmt.catches.map(astForCatchStmt) + val finallyBody = stmt.finallyStmt.map(fin => stmtBodyBlockAst(fin)) - private def astForIfStmt(ifStmt: PhpIfStmt): Ast = - val condition = astForExpr(ifStmt.cond) + val tryNode = controlStructureNode(stmt, ControlStructureTypes.TRY, "try { ... }") - val thenAst = stmtBodyBlockAst(ifStmt) + tryCatchAst(tryNode, tryBody, catches, finallyBody) - val elseAst = ifStmt.elseIfs match - case Nil => ifStmt.elseStmt.map(els => stmtBodyBlockAst(els)).toList + private def astForReturnStmt(stmt: PhpReturnStmt): Ast = + val maybeExprAst = stmt.expr.map(astForExpr) + val code = s"return ${maybeExprAst.map(_.rootCodeOrEmpty).getOrElse("")}" - case elseIf :: rest => - val newIfStmt = - PhpIfStmt(elseIf.cond, elseIf.stmts, rest, ifStmt.elseStmt, elseIf.attributes) - val wrappingBlock = blockNode(elseIf) - val wrappedAst = Ast(wrappingBlock).withChild(astForIfStmt(newIfStmt)) :: Nil - wrappedAst + val node = returnNode(stmt, code) - val conditionCode = condition.rootCodeOrEmpty - val ifNode = controlStructureNode(ifStmt, ControlStructureTypes.IF, s"if ($conditionCode)") + returnAst(node, maybeExprAst.toList) - controlStructureAst(ifNode, Option(condition), thenAst :: elseAst) - end astForIfStmt + private def astForClassLikeStmt(stmt: PhpClassLikeStmt): Ast = + stmt.name match + case None => astForAnonymousClass(stmt) + case Some(name) => astForNamedClass(stmt, name) - private def astForSwitchStmt(stmt: PhpSwitchStmt): Ast = - val conditionAst = astForExpr(stmt.condition) + private def astForGotoStmt(stmt: PhpGotoStmt): Ast = + val label = stmt.label.name + val code = s"goto $label" - val switchNode = - controlStructureNode( - stmt, - ControlStructureTypes.SWITCH, - s"switch (${conditionAst.rootCodeOrEmpty})" - ) - - val switchBodyBlock = blockNode(stmt) - val entryAsts = stmt.cases.flatMap(astsForSwitchCase) - val switchBody = Ast(switchBodyBlock).withChildren(entryAsts) - - controlStructureAst(switchNode, Option(conditionAst), switchBody :: Nil) - - private def astForTryStmt(stmt: PhpTryStmt): Ast = - val tryBody = stmtBodyBlockAst(stmt) - val catches = stmt.catches.map(astForCatchStmt) - val finallyBody = stmt.finallyStmt.map(fin => stmtBodyBlockAst(fin)) - - val tryNode = controlStructureNode(stmt, ControlStructureTypes.TRY, "try { ... }") - - tryCatchAst(tryNode, tryBody, catches, finallyBody) - - private def astForReturnStmt(stmt: PhpReturnStmt): Ast = - val maybeExprAst = stmt.expr.map(astForExpr) - val code = s"return ${maybeExprAst.map(_.rootCodeOrEmpty).getOrElse("")}" + val gotoNode = controlStructureNode(stmt, ControlStructureTypes.GOTO, code) - val node = returnNode(stmt, code) + val jumpLabel = NewJumpLabel() + .name(label) + .code(label) + .lineNumber(line(stmt)) + .columnNumber(column(stmt)) - returnAst(node, maybeExprAst.toList) + controlStructureAst(gotoNode, condition = None, children = Ast(jumpLabel) :: Nil) - private def astForClassLikeStmt(stmt: PhpClassLikeStmt): Ast = - stmt.name match - case None => astForAnonymousClass(stmt) - case Some(name) => astForNamedClass(stmt, name) + private def astForLabelStmt(stmt: PhpLabelStmt): Ast = + val label = stmt.label.name - private def astForGotoStmt(stmt: PhpGotoStmt): Ast = - val label = stmt.label.name - val code = s"goto $label" + val jumpTarget = NewJumpTarget() + .name(label) + .code(label) + .lineNumber(line(stmt)) + .columnNumber(column(stmt)) - val gotoNode = controlStructureNode(stmt, ControlStructureTypes.GOTO, code) + Ast(jumpTarget) - val jumpLabel = NewJumpLabel() - .name(label) - .code(label) - .lineNumber(line(stmt)) - .columnNumber(column(stmt)) + private def astForNamespaceStmt(stmt: PhpNamespaceStmt): Ast = + val name = stmt.name.map(_.name).getOrElse(NameConstants.Unknown) + val fullName = s"$filename:$name" - controlStructureAst(gotoNode, condition = None, children = Ast(jumpLabel) :: Nil) + val namespaceBlock = NewNamespaceBlock() + .name(name) + .fullName(fullName) + .lineNumber(line(stmt)) + .columnNumber(column(stmt)) - private def astForLabelStmt(stmt: PhpLabelStmt): Ast = - val label = stmt.label.name + scope.pushNewScope(namespaceBlock) + val bodyStmts = astsForClassLikeBody(stmt, stmt.stmts, createDefaultConstructor = false) + scope.popScope() - val jumpTarget = NewJumpTarget() - .name(label) - .code(label) - .lineNumber(line(stmt)) - .columnNumber(column(stmt)) + Ast(namespaceBlock).withChildren(bodyStmts) - Ast(jumpTarget) - - private def astForNamespaceStmt(stmt: PhpNamespaceStmt): Ast = - val name = stmt.name.map(_.name).getOrElse(NameConstants.Unknown) - val fullName = s"$filename:$name" - - val namespaceBlock = NewNamespaceBlock() - .name(name) - .fullName(fullName) - .lineNumber(line(stmt)) - .columnNumber(column(stmt)) - - scope.pushNewScope(namespaceBlock) - val bodyStmts = astsForClassLikeBody(stmt, stmt.stmts, createDefaultConstructor = false) - scope.popScope() - - Ast(namespaceBlock).withChildren(bodyStmts) - - private def astForDeclareStmt(stmt: PhpDeclareStmt): Ast = - val declareAssignAsts = stmt.declares.map(astForDeclareItem) - val declareCode = - s"${PhpOperators.declareFunc}(${declareAssignAsts.map(_.rootCodeOrEmpty).mkString(",")})" - val declareNode = - newOperatorCallNode( - PhpOperators.declareFunc, - declareCode, - line = line(stmt), - column = column(stmt) - ) - val declareAst = callAst(declareNode, declareAssignAsts) + private def astForDeclareStmt(stmt: PhpDeclareStmt): Ast = + val declareAssignAsts = stmt.declares.map(astForDeclareItem) + val declareCode = + s"${PhpOperators.declareFunc}(${declareAssignAsts.map(_.rootCodeOrEmpty).mkString(",")})" + val declareNode = + newOperatorCallNode( + PhpOperators.declareFunc, + declareCode, + line = line(stmt), + column = column(stmt) + ) + val declareAst = callAst(declareNode, declareAssignAsts) + + stmt.stmts match + case Some(stmtList) => + val stmtAsts = stmtList.flatMap(astsForStmt) + Ast(blockNode(stmt)) + .withChild(declareAst) + .withChildren(stmtAsts) + + case None => declareAst + end astForDeclareStmt + + private def astForDeclareItem(item: PhpDeclareItem): Ast = + val key = identifierNode(item, item.key.name, item.key.name, "ANY") + val value = astForExpr(item.value) + val code = s"${key.name}=${value.rootCodeOrEmpty}" + + val declareAssignment = newOperatorCallNode( + Operators.assignment, + code, + line = line(item), + column = column(item) + ) + callAst(declareAssignment, Ast(key) :: value :: Nil) + + private def astForHaltCompilerStmt(stmt: PhpHaltCompilerStmt): Ast = + val call = newOperatorCallNode( + NameConstants.HaltCompiler, + s"${NameConstants.HaltCompiler}()", + Some(TypeConstants.Void), + line(stmt), + column(stmt) + ) - stmt.stmts match - case Some(stmtList) => - val stmtAsts = stmtList.flatMap(astsForStmt) - Ast(blockNode(stmt)) - .withChild(declareAst) - .withChildren(stmtAsts) + Ast(call) + + private def astForUnsetStmt(stmt: PhpUnsetStmt): Ast = + val name = PhpOperators.unset + val args = stmt.vars.map(astForExpr) + val code = s"$name(${args.map(_.rootCodeOrEmpty).mkString(", ")})" + val callNode = newOperatorCallNode( + name, + code, + typeFullName = Some(TypeConstants.Void), + line = line(stmt), + column = column(stmt) + ) + .methodFullName(PhpOperators.unset) + callAst(callNode, args) - case None => declareAst - end astForDeclareStmt + private def astForGlobalStmt(stmt: PhpGlobalStmt): Ast = + // This isn't an accurater representation of what `global` does, but with things like `global $$x` being possible, + // it's very difficult to figure out correct scopes for global variables. - private def astForDeclareItem(item: PhpDeclareItem): Ast = - val key = identifierNode(item, item.key.name, item.key.name, "ANY") - val value = astForExpr(item.value) - val code = s"${key.name}=${value.rootCodeOrEmpty}" + val varsAsts = stmt.vars.map(astForExpr) + val code = s"${PhpOperators.global} ${varsAsts.map(_.rootCodeOrEmpty).mkString(", ")}" - val declareAssignment = newOperatorCallNode( - Operators.assignment, + val globalCallNode = + newOperatorCallNode( + PhpOperators.global, code, - line = line(item), - column = column(item) - ) - callAst(declareAssignment, Ast(key) :: value :: Nil) - - private def astForHaltCompilerStmt(stmt: PhpHaltCompilerStmt): Ast = - val call = newOperatorCallNode( - NameConstants.HaltCompiler, - s"${NameConstants.HaltCompiler}()", Some(TypeConstants.Void), line(stmt), column(stmt) ) - Ast(call) - - private def astForUnsetStmt(stmt: PhpUnsetStmt): Ast = - val name = PhpOperators.unset - val args = stmt.vars.map(astForExpr) - val code = s"$name(${args.map(_.rootCodeOrEmpty).mkString(", ")})" - val callNode = newOperatorCallNode( - name, - code, - typeFullName = Some(TypeConstants.Void), + callAst(globalCallNode, varsAsts) + + private def astForUseStmt(stmt: PhpUseStmt): Ast = + // TODO Use useType + scope to get better name info + val imports = stmt.uses.map(astForUseUse(_)) + wrapMultipleInBlock(imports, line(stmt), column(stmt)) + + private def astForGroupUseStmt(stmt: PhpGroupUseStmt): Ast = + // TODO Use useType + scope to get better name info + val groupPrefix = s"${stmt.prefix.name}\\" + val imports = stmt.uses.map(astForUseUse(_, groupPrefix)) + wrapMultipleInBlock(imports, line(stmt), column(stmt)) + + private def astForKeyValPair( + key: PhpExpr, + value: PhpExpr, + lineNo: Option[Integer], + colNo: Option[Integer] + ): Ast = + val keyAst = astForExpr(key) + val valueAst = astForExpr(value) + + val code = s"${keyAst.rootCodeOrEmpty} => ${valueAst.rootCodeOrEmpty}" + val callNode = + newOperatorCallNode(PhpOperators.doubleArrow, code, line = lineNo, column = colNo) + callAst(callNode, keyAst :: valueAst :: Nil) + + private def astForForeachStmt(stmt: PhpForeachStmt): Ast = + val iteratorAst = astForExpr(stmt.iterExpr) + val iterIdentifier = getTmpIdentifier(stmt, maybeTypeFullName = None, prefix = "iter_") + + val assignItemTargetAst = stmt.keyVar match + case Some(key) => astForKeyValPair(key, stmt.valueVar, line(stmt), column(stmt)) + case None => astForExpr(stmt.valueVar) + + // Initializer asts + // - Iterator assign + val iterValue = astForExpr(stmt.iterExpr) + val iteratorAssignAst = + simpleAssignAst(Ast(iterIdentifier), iterValue, line(stmt), column(stmt)) + + // - Assigned item assign + val itemInitAst = getItemAssignAstForForeach(stmt, assignItemTargetAst, iterIdentifier.copy) + + // Condition ast + val isNullName = PhpOperators.isNull + val valueAst = astForExpr(stmt.valueVar) + val isNullCode = s"$isNullName(${valueAst.rootCodeOrEmpty})" + val isNullCall = + newOperatorCallNode( + isNullName, + isNullCode, + Some(TypeConstants.Bool), + line(stmt), + column(stmt) + ) + .methodFullName(PhpOperators.isNull) + val notIsNull = + newOperatorCallNode( + Operators.logicalNot, + s"!$isNullCode", line = line(stmt), column = column(stmt) ) - .methodFullName(PhpOperators.unset) - callAst(callNode, args) - - private def astForGlobalStmt(stmt: PhpGlobalStmt): Ast = - // This isn't an accurater representation of what `global` does, but with things like `global $$x` being possible, - // it's very difficult to figure out correct scopes for global variables. - - val varsAsts = stmt.vars.map(astForExpr) - val code = s"${PhpOperators.global} ${varsAsts.map(_.rootCodeOrEmpty).mkString(", ")}" - - val globalCallNode = - newOperatorCallNode( - PhpOperators.global, - code, - Some(TypeConstants.Void), - line(stmt), - column(stmt) - ) - - callAst(globalCallNode, varsAsts) - - private def astForUseStmt(stmt: PhpUseStmt): Ast = - // TODO Use useType + scope to get better name info - val imports = stmt.uses.map(astForUseUse(_)) - wrapMultipleInBlock(imports, line(stmt), column(stmt)) - - private def astForGroupUseStmt(stmt: PhpGroupUseStmt): Ast = - // TODO Use useType + scope to get better name info - val groupPrefix = s"${stmt.prefix.name}\\" - val imports = stmt.uses.map(astForUseUse(_, groupPrefix)) - wrapMultipleInBlock(imports, line(stmt), column(stmt)) - - private def astForKeyValPair( - key: PhpExpr, - value: PhpExpr, - lineNo: Option[Integer], - colNo: Option[Integer] - ): Ast = - val keyAst = astForExpr(key) - val valueAst = astForExpr(value) - - val code = s"${keyAst.rootCodeOrEmpty} => ${valueAst.rootCodeOrEmpty}" - val callNode = - newOperatorCallNode(PhpOperators.doubleArrow, code, line = lineNo, column = colNo) - callAst(callNode, keyAst :: valueAst :: Nil) - - private def astForForeachStmt(stmt: PhpForeachStmt): Ast = - val iteratorAst = astForExpr(stmt.iterExpr) - val iterIdentifier = getTmpIdentifier(stmt, maybeTypeFullName = None, prefix = "iter_") - - val assignItemTargetAst = stmt.keyVar match - case Some(key) => astForKeyValPair(key, stmt.valueVar, line(stmt), column(stmt)) - case None => astForExpr(stmt.valueVar) - - // Initializer asts - // - Iterator assign - val iterValue = astForExpr(stmt.iterExpr) - val iteratorAssignAst = - simpleAssignAst(Ast(iterIdentifier), iterValue, line(stmt), column(stmt)) - - // - Assigned item assign - val itemInitAst = getItemAssignAstForForeach(stmt, assignItemTargetAst, iterIdentifier.copy) - - // Condition ast - val isNullName = PhpOperators.isNull - val valueAst = astForExpr(stmt.valueVar) - val isNullCode = s"$isNullName(${valueAst.rootCodeOrEmpty})" - val isNullCall = - newOperatorCallNode( - isNullName, - isNullCode, - Some(TypeConstants.Bool), - line(stmt), - column(stmt) - ) - .methodFullName(PhpOperators.isNull) - val notIsNull = - newOperatorCallNode( - Operators.logicalNot, - s"!$isNullCode", - line = line(stmt), - column = column(stmt) - ) - val isNullAst = callAst(isNullCall, valueAst :: Nil) - val conditionAst = callAst(notIsNull, isNullAst :: Nil) - - // Update asts - val nextIterIdent = Ast(iterIdentifier.copy) - val nextSignature = "void()" - val nextCallCode = s"${nextIterIdent.rootCodeOrEmpty}->next()" - val nextCallNode = callNode( + val isNullAst = callAst(isNullCall, valueAst :: Nil) + val conditionAst = callAst(notIsNull, isNullAst :: Nil) + + // Update asts + val nextIterIdent = Ast(iterIdentifier.copy) + val nextSignature = "void()" + val nextCallCode = s"${nextIterIdent.rootCodeOrEmpty}->next()" + val nextCallNode = callNode( + stmt, + nextCallCode, + "next", + "Iterator.next", + DispatchTypes.DYNAMIC_DISPATCH, + Some(nextSignature), + Some(TypeConstants.Any) + ) + val nextCallAst = callAst(nextCallNode, base = Option(nextIterIdent)) + val itemUpdateAst = itemInitAst.root match + case Some(initRoot: AstNodeNew) => itemInitAst.subTreeCopy(initRoot) + case _ => + logger.debug(s"Could not copy foreach init ast in $filename") + Ast() + + val bodyAst = stmtBodyBlockAst(stmt) + + val ampPrefix = if stmt.assignByRef then "&" else "" + val foreachCode = + s"foreach (${iteratorAst.rootCodeOrEmpty} as $ampPrefix${assignItemTargetAst.rootCodeOrEmpty})" + val foreachNode = controlStructureNode(stmt, ControlStructureTypes.FOR, foreachCode) + Ast(foreachNode) + .withChild(wrapMultipleInBlock( + iteratorAssignAst :: itemInitAst :: Nil, + line(stmt), + column(stmt) + )) + .withChild(conditionAst) + .withChild(wrapMultipleInBlock( + nextCallAst :: itemUpdateAst :: Nil, + line(stmt), + column(stmt) + )) + .withChild(bodyAst) + .withConditionEdges(foreachNode, conditionAst.root.toList) + end astForForeachStmt + + private def getItemAssignAstForForeach( + stmt: PhpForeachStmt, + assignItemTargetAst: Ast, + iteratorIdentifier: NewIdentifier + ): Ast = + val iteratorIdentifierAst = Ast(iteratorIdentifier) + val currentCallSignature = s"$UnresolvedSignature(0)" + val currentCallCode = s"${iteratorIdentifierAst.rootCodeOrEmpty}->current()" + val currentCallNode = callNode( + stmt, + currentCallCode, + "current", + "Iterator.current", + DispatchTypes.DYNAMIC_DISPATCH, + Some(currentCallSignature), + Some(TypeConstants.Any) + ); + val currentCallAst = callAst(currentCallNode, base = Option(iteratorIdentifierAst)) + + val valueAst = if stmt.assignByRef then + val addressOfCode = s"&${currentCallAst.rootCodeOrEmpty}" + val addressOfCall = + newOperatorCallNode( + Operators.addressOf, + addressOfCode, + line = line(stmt), + column = column(stmt) + ) + callAst(addressOfCall, currentCallAst :: Nil) + else + currentCallAst + + simpleAssignAst(assignItemTargetAst, valueAst, line(stmt), column(stmt)) + end getItemAssignAstForForeach + + private def simpleAssignAst( + target: Ast, + source: Ast, + lineNo: Option[Integer], + colNo: Option[Integer] + ): Ast = + val code = s"${target.rootCodeOrEmpty} = ${source.rootCodeOrEmpty}" + val callNode = + newOperatorCallNode(Operators.assignment, code, line = lineNo, column = colNo) + callAst(callNode, target :: source :: Nil) + + private def astforTraitUseStmt(stmt: PhpTraitUseStmt): Ast = + // TODO Actually implement this + Ast() + + private def astForUseUse(stmt: PhpUseUse, namePrefix: String = ""): Ast = + val originalName = s"$namePrefix${stmt.originalName.name}" + val aliasCode = stmt.alias.map(alias => s" as ${alias.name}").getOrElse("") + val typeCode = stmt.useType match + case PhpUseType.Function => s"function " + case PhpUseType.Constant => s"const " + case _ => "" + val code = s"use $typeCode$originalName$aliasCode" + val importNode = NewImport() + .importedEntity(originalName) + .importedAs(stmt.alias.map(_.name)) + .isExplicit(true) + .code(code) + .lineNumber(line(stmt)) + .columnNumber(column(stmt)) + + Ast(importNode) + + private def astsForStaticStmt(stmt: PhpStaticStmt): List[Ast] = + stmt.vars.flatMap { staticVarDecl => + val variableAst = astForVariableExpr(staticVarDecl.variable) + val maybeValueAst = staticVarDecl.defaultValue.map(astForExpr) + + val code = variableAst.rootCode.getOrElse(NameConstants.Unknown) + val name = variableAst.root match + case Some(identifier: NewIdentifier) => identifier.name + case _ => code + + val local = localNode( stmt, - nextCallCode, - "next", - "Iterator.next", - DispatchTypes.DYNAMIC_DISPATCH, - Some(nextSignature), - Some(TypeConstants.Any) + name, + s"static $code", + variableAst.rootType.getOrElse(TypeConstants.Any) ) - val nextCallAst = callAst(nextCallNode, base = Option(nextIterIdent)) - val itemUpdateAst = itemInitAst.root match - case Some(initRoot: AstNodeNew) => itemInitAst.subTreeCopy(initRoot) - case _ => - logger.debug(s"Could not copy foreach init ast in $filename") - Ast() - - val bodyAst = stmtBodyBlockAst(stmt) - - val ampPrefix = if stmt.assignByRef then "&" else "" - val foreachCode = - s"foreach (${iteratorAst.rootCodeOrEmpty} as $ampPrefix${assignItemTargetAst.rootCodeOrEmpty})" - val foreachNode = controlStructureNode(stmt, ControlStructureTypes.FOR, foreachCode) - Ast(foreachNode) - .withChild(wrapMultipleInBlock( - iteratorAssignAst :: itemInitAst :: Nil, - line(stmt), - column(stmt) - )) - .withChild(conditionAst) - .withChild(wrapMultipleInBlock( - nextCallAst :: itemUpdateAst :: Nil, - line(stmt), - column(stmt) - )) - .withChild(bodyAst) - .withConditionEdges(foreachNode, conditionAst.root.toList) - end astForForeachStmt - - private def getItemAssignAstForForeach( - stmt: PhpForeachStmt, - assignItemTargetAst: Ast, - iteratorIdentifier: NewIdentifier - ): Ast = - val iteratorIdentifierAst = Ast(iteratorIdentifier) - val currentCallSignature = s"$UnresolvedSignature(0)" - val currentCallCode = s"${iteratorIdentifierAst.rootCodeOrEmpty}->current()" - val currentCallNode = callNode( - stmt, - currentCallCode, - "current", - "Iterator.current", - DispatchTypes.DYNAMIC_DISPATCH, - Some(currentCallSignature), - Some(TypeConstants.Any) - ); - val currentCallAst = callAst(currentCallNode, base = Option(iteratorIdentifierAst)) - - val valueAst = if stmt.assignByRef then - val addressOfCode = s"&${currentCallAst.rootCodeOrEmpty}" - val addressOfCall = - newOperatorCallNode( - Operators.addressOf, - addressOfCode, - line = line(stmt), - column = column(stmt) - ) - callAst(addressOfCall, currentCallAst :: Nil) - else - currentCallAst - - simpleAssignAst(assignItemTargetAst, valueAst, line(stmt), column(stmt)) - end getItemAssignAstForForeach - - private def simpleAssignAst( - target: Ast, - source: Ast, - lineNo: Option[Integer], - colNo: Option[Integer] - ): Ast = - val code = s"${target.rootCodeOrEmpty} = ${source.rootCodeOrEmpty}" - val callNode = - newOperatorCallNode(Operators.assignment, code, line = lineNo, column = colNo) - callAst(callNode, target :: source :: Nil) - - private def astforTraitUseStmt(stmt: PhpTraitUseStmt): Ast = - // TODO Actually implement this - Ast() - - private def astForUseUse(stmt: PhpUseUse, namePrefix: String = ""): Ast = - val originalName = s"$namePrefix${stmt.originalName.name}" - val aliasCode = stmt.alias.map(alias => s" as ${alias.name}").getOrElse("") - val typeCode = stmt.useType match - case PhpUseType.Function => s"function " - case PhpUseType.Constant => s"const " - case _ => "" - val code = s"use $typeCode$originalName$aliasCode" - val importNode = NewImport() - .importedEntity(originalName) - .importedAs(stmt.alias.map(_.name)) - .isExplicit(true) - .code(code) - .lineNumber(line(stmt)) - .columnNumber(column(stmt)) - - Ast(importNode) - - private def astsForStaticStmt(stmt: PhpStaticStmt): List[Ast] = - stmt.vars.flatMap { staticVarDecl => - val variableAst = astForVariableExpr(staticVarDecl.variable) - val maybeValueAst = staticVarDecl.defaultValue.map(astForExpr) - - val code = variableAst.rootCode.getOrElse(NameConstants.Unknown) - val name = variableAst.root match - case Some(identifier: NewIdentifier) => identifier.name - case _ => code - - val local = localNode( - stmt, - name, - s"static $code", - variableAst.rootType.getOrElse(TypeConstants.Any) - ) - scope.addToScope(local.name, local) - - variableAst.root.collect { case identifier: NewIdentifier => - diffGraph.addEdge(identifier, local, EdgeTypes.REF) - } - - val defaultAssignAst = maybeValueAst.map { valueAst => - val valueCode = s"static $code = ${valueAst.rootCodeOrEmpty}" - val assignNode = - newOperatorCallNode( - Operators.assignment, - valueCode, - line = line(stmt), - column = column(stmt) - ) - callAst(assignNode, variableAst :: valueAst :: Nil) - } - - Ast(local) :: defaultAssignAst.toList - } + scope.addToScope(local.name, local) - private def astForAnonymousClass(stmt: PhpClassLikeStmt): Ast = - // TODO - Ast() - - def codeForClassStmt(stmt: PhpClassLikeStmt, name: PhpNameExpr): String = - // TODO Extend for anonymous classes - val extendsString = stmt.extendsNames match - case Nil => "" - case names => s" extends ${names.map(_.name).mkString(", ")}" - val implementsString = - if stmt.implementedInterfaces.isEmpty then - "" - else - s" implements ${stmt.implementedInterfaces.map(_.name).mkString(", ")}" - - s"${stmt.classLikeType} ${name.name}$extendsString$implementsString" - - private def astForNamedClass(stmt: PhpClassLikeStmt, name: PhpNameExpr): Ast = - val inheritsFrom = (stmt.extendsNames ++ stmt.implementedInterfaces).map(_.name) - val code = codeForClassStmt(stmt, name) - - val fullName = - if name.name == NamespaceTraversal.globalNamespaceName then - globalNamespace.fullName - else - prependNamespacePrefix(name.name) - - val typeDecl = - typeDeclNode(stmt, name.name, fullName, filename, code, inherits = inheritsFrom) - - val createDefaultConstructor = stmt.hasConstructor - - scope.pushNewScope(typeDecl) - val bodyStmts = astsForClassLikeBody(stmt, stmt.stmts, createDefaultConstructor) - val modifiers = stmt.modifiers.map(newModifierNode).map(Ast(_)) - scope.popScope() - - Ast(typeDecl).withChildren(modifiers).withChildren(bodyStmts) - end astForNamedClass - - private def astForStaticAndConstInits: Option[Ast] = - scope.getConstAndStaticInits match - case Nil => None - - case inits => - val signature = s"${TypeConstants.Void}()" - val fullName = composeMethodFullName(StaticInitMethodName, isStatic = true) - val ast = staticInitMethodAst( - inits, - fullName, - Option(signature), - TypeConstants.Void, - fileName = Some(filename) - ) - Option(ast) - - private def astsForClassLikeBody( - classLike: PhpStmt, - bodyStmts: List[PhpStmt], - createDefaultConstructor: Boolean - ): List[Ast] = - val classConsts = - bodyStmts.collect { case cs: PhpConstStmt => cs }.flatMap(astsForConstStmt) - val properties = - bodyStmts.collect { case cp: PhpPropertyStmt => cp }.flatMap(astsForPropertyStmt) - - val explicitConstructorAst = bodyStmts.collectFirst { - case m: PhpMethodDecl if m.name.name == ConstructorMethodName => astForConstructor(m) + variableAst.root.collect { case identifier: NewIdentifier => + diffGraph.addEdge(identifier, local, EdgeTypes.REF) } - val constructorAst = - explicitConstructorAst.orElse( - Option.when(createDefaultConstructor)(defaultConstructorAst(classLike)) - ) - - val otherBodyStmts = bodyStmts.flatMap { - case _: PhpConstStmt => Nil // Handled above + val defaultAssignAst = maybeValueAst.map { valueAst => + val valueCode = s"static $code = ${valueAst.rootCodeOrEmpty}" + val assignNode = + newOperatorCallNode( + Operators.assignment, + valueCode, + line = line(stmt), + column = column(stmt) + ) + callAst(assignNode, variableAst :: valueAst :: Nil) + } - case _: PhpPropertyStmt => Nil // Handled above + Ast(local) :: defaultAssignAst.toList + } + + private def astForAnonymousClass(stmt: PhpClassLikeStmt): Ast = + // TODO + Ast() + + def codeForClassStmt(stmt: PhpClassLikeStmt, name: PhpNameExpr): String = + // TODO Extend for anonymous classes + val extendsString = stmt.extendsNames match + case Nil => "" + case names => s" extends ${names.map(_.name).mkString(", ")}" + val implementsString = + if stmt.implementedInterfaces.isEmpty then + "" + else + s" implements ${stmt.implementedInterfaces.map(_.name).mkString(", ")}" - case method: PhpMethodDecl if method.name.name == ConstructorMethodName => - Nil // Handled above + s"${stmt.classLikeType} ${name.name}$extendsString$implementsString" - // Not all statements are supported in class bodies, but since this is re-used for namespaces - // we allow that here. - case stmt => astsForStmt(stmt) - } + private def astForNamedClass(stmt: PhpClassLikeStmt, name: PhpNameExpr): Ast = + val inheritsFrom = (stmt.extendsNames ++ stmt.implementedInterfaces).map(_.name) + val code = codeForClassStmt(stmt, name) - val clinitAst = astForStaticAndConstInits - val anonymousMethodAsts = scope.getAndClearAnonymousMethods - - List( - classConsts, - properties, - clinitAst, - constructorAst, - anonymousMethodAsts, - otherBodyStmts - ).flatten - end astsForClassLikeBody - - private def astForConstructor(constructorDecl: PhpMethodDecl): Ast = - val fieldInits = scope.getFieldInits - astForMethodDecl(constructorDecl, fieldInits, isConstructor = true) - - private def prependNamespacePrefix(name: String): String = - scope.getEnclosingNamespaceNames.filterNot( - _ == NamespaceTraversal.globalNamespaceName - ) match - case Nil => name - case names => names.appended(name).mkString(NamespaceDelimiter) - - private def getTypeDeclPrefix: Option[String] = - scope.getEnclosingTypeDeclTypeName - .filterNot(_ == NamespaceTraversal.globalNamespaceName) - - private def defaultConstructorAst(originNode: PhpNode): Ast = - val fullName = composeMethodFullName(ConstructorMethodName, isStatic = false) - - val signature = s"$UnresolvedSignature(0)" - - val modifiers = List( - ModifierTypes.VIRTUAL, - ModifierTypes.PUBLIC, - ModifierTypes.CONSTRUCTOR - ).map(newModifierNode) - - val thisParam = thisParamAstForMethod(originNode) - - val method = methodNode( - originNode, - ConstructorMethodName, - fullName, - fullName, - Some(signature), - filename - ) - - val methodBody = blockAst(blockNode(originNode), scope.getFieldInits) - - val methodReturn = newMethodReturnNode(TypeConstants.Any, line = None, column = None) - - methodAstWithAnnotations(method, thisParam :: Nil, methodBody, methodReturn, modifiers) - end defaultConstructorAst - - private def astForMemberAssignment( - memberNode: NewMember, - valueExpr: PhpExpr, - isField: Boolean - ): Ast = - val targetAst = if isField then - val code = s"$$this->${memberNode.name}" - val fieldAccessNode = - newOperatorCallNode(Operators.fieldAccess, code, line = memberNode.lineNumber) - val identifier = thisIdentifier(memberNode.lineNumber) - val thisParam = scope.lookupVariable(NameConstants.This) - val fieldIdentifier = newFieldIdentifierNode(memberNode.name, memberNode.lineNumber) - callAst(fieldAccessNode, List(identifier, fieldIdentifier).map(Ast(_))).withRefEdges( - identifier, - thisParam.toList - ) + val fullName = + if name.name == NamespaceTraversal.globalNamespaceName then + globalNamespace.fullName else - val identifierCode = memberNode.code.replaceAll("const ", "").replaceAll("case ", "") - val typeFullName = Option(memberNode.typeFullName) - val identifier = newIdentifierNode(memberNode.name, typeFullName.getOrElse("ANY")) - .code(identifierCode) - Ast(identifier).withRefEdge(identifier, memberNode) - val value = astForExpr(valueExpr) - - val assignmentCode = s"${targetAst.rootCodeOrEmpty} = ${value.rootCodeOrEmpty}" - val callNode = - newOperatorCallNode(Operators.assignment, assignmentCode, line = memberNode.lineNumber) - - callAst(callNode, List(targetAst, value)) - end astForMemberAssignment - - private def astsForConstStmt(stmt: PhpConstStmt): List[Ast] = - stmt.consts.map { constDecl => - val finalModifier = Ast(newModifierNode(ModifierTypes.FINAL)) - // `final const` is not allowed, so this is a safe way to represent constants in the CPG - val modifierAsts = finalModifier :: stmt.modifiers.map(newModifierNode).map(Ast(_)) - - val name = constDecl.name.name - val code = s"const $name" - val someValue = Option(constDecl.value) - astForConstOrFieldValue( - stmt, - name, - code, - someValue, - scope.addConstOrStaticInitToScope, - isField = false + prependNamespacePrefix(name.name) + + val typeDecl = + typeDeclNode(stmt, name.name, fullName, filename, code, inherits = inheritsFrom) + + val createDefaultConstructor = stmt.hasConstructor + + scope.pushNewScope(typeDecl) + val bodyStmts = astsForClassLikeBody(stmt, stmt.stmts, createDefaultConstructor) + val modifiers = stmt.modifiers.map(newModifierNode).map(Ast(_)) + scope.popScope() + + Ast(typeDecl).withChildren(modifiers).withChildren(bodyStmts) + end astForNamedClass + + private def astForStaticAndConstInits: Option[Ast] = + scope.getConstAndStaticInits match + case Nil => None + + case inits => + val signature = s"${TypeConstants.Void}()" + val fullName = composeMethodFullName(StaticInitMethodName, isStatic = true) + val ast = staticInitMethodAst( + inits, + fullName, + Option(signature), + TypeConstants.Void, + fileName = Some(filename) ) - .withChildren(modifierAsts) - } + Option(ast) + + private def astsForClassLikeBody( + classLike: PhpStmt, + bodyStmts: List[PhpStmt], + createDefaultConstructor: Boolean + ): List[Ast] = + val classConsts = + bodyStmts.collect { case cs: PhpConstStmt => cs }.flatMap(astsForConstStmt) + val properties = + bodyStmts.collect { case cp: PhpPropertyStmt => cp }.flatMap(astsForPropertyStmt) + + val explicitConstructorAst = bodyStmts.collectFirst { + case m: PhpMethodDecl if m.name.name == ConstructorMethodName => astForConstructor(m) + } + + val constructorAst = + explicitConstructorAst.orElse( + Option.when(createDefaultConstructor)(defaultConstructorAst(classLike)) + ) - private def astForEnumCase(stmt: PhpEnumCaseStmt): Ast = - val finalModifier = Ast(newModifierNode(ModifierTypes.FINAL)) + val otherBodyStmts = bodyStmts.flatMap { + case _: PhpConstStmt => Nil // Handled above + + case _: PhpPropertyStmt => Nil // Handled above + + case method: PhpMethodDecl if method.name.name == ConstructorMethodName => + Nil // Handled above + + // Not all statements are supported in class bodies, but since this is re-used for namespaces + // we allow that here. + case stmt => astsForStmt(stmt) + } + + val clinitAst = astForStaticAndConstInits + val anonymousMethodAsts = scope.getAndClearAnonymousMethods + + List( + classConsts, + properties, + clinitAst, + constructorAst, + anonymousMethodAsts, + otherBodyStmts + ).flatten + end astsForClassLikeBody + + private def astForConstructor(constructorDecl: PhpMethodDecl): Ast = + val fieldInits = scope.getFieldInits + astForMethodDecl(constructorDecl, fieldInits, isConstructor = true) + + private def prependNamespacePrefix(name: String): String = + scope.getEnclosingNamespaceNames.filterNot( + _ == NamespaceTraversal.globalNamespaceName + ) match + case Nil => name + case names => names.appended(name).mkString(NamespaceDelimiter) + + private def getTypeDeclPrefix: Option[String] = + scope.getEnclosingTypeDeclTypeName + .filterNot(_ == NamespaceTraversal.globalNamespaceName) + + private def defaultConstructorAst(originNode: PhpNode): Ast = + val fullName = composeMethodFullName(ConstructorMethodName, isStatic = false) + + val signature = s"$UnresolvedSignature(0)" + + val modifiers = List( + ModifierTypes.VIRTUAL, + ModifierTypes.PUBLIC, + ModifierTypes.CONSTRUCTOR + ).map(newModifierNode) + + val thisParam = thisParamAstForMethod(originNode) + + val method = methodNode( + originNode, + ConstructorMethodName, + fullName, + fullName, + Some(signature), + filename + ) - val name = stmt.name.name - val code = s"case $name" + val methodBody = blockAst(blockNode(originNode), scope.getFieldInits) + + val methodReturn = newMethodReturnNode(TypeConstants.Any, line = None, column = None) + + methodAstWithAnnotations(method, thisParam :: Nil, methodBody, methodReturn, modifiers) + end defaultConstructorAst + + private def astForMemberAssignment( + memberNode: NewMember, + valueExpr: PhpExpr, + isField: Boolean + ): Ast = + val targetAst = if isField then + val code = s"$$this->${memberNode.name}" + val fieldAccessNode = + newOperatorCallNode(Operators.fieldAccess, code, line = memberNode.lineNumber) + val identifier = thisIdentifier(memberNode.lineNumber) + val thisParam = scope.lookupVariable(NameConstants.This) + val fieldIdentifier = newFieldIdentifierNode(memberNode.name, memberNode.lineNumber) + callAst(fieldAccessNode, List(identifier, fieldIdentifier).map(Ast(_))).withRefEdges( + identifier, + thisParam.toList + ) + else + val identifierCode = memberNode.code.replaceAll("const ", "").replaceAll("case ", "") + val typeFullName = Option(memberNode.typeFullName) + val identifier = newIdentifierNode(memberNode.name, typeFullName.getOrElse("ANY")) + .code(identifierCode) + Ast(identifier).withRefEdge(identifier, memberNode) + val value = astForExpr(valueExpr) + + val assignmentCode = s"${targetAst.rootCodeOrEmpty} = ${value.rootCodeOrEmpty}" + val callNode = + newOperatorCallNode(Operators.assignment, assignmentCode, line = memberNode.lineNumber) + + callAst(callNode, List(targetAst, value)) + end astForMemberAssignment + + private def astsForConstStmt(stmt: PhpConstStmt): List[Ast] = + stmt.consts.map { constDecl => + val finalModifier = Ast(newModifierNode(ModifierTypes.FINAL)) + // `final const` is not allowed, so this is a safe way to represent constants in the CPG + val modifierAsts = finalModifier :: stmt.modifiers.map(newModifierNode).map(Ast(_)) + val name = constDecl.name.name + val code = s"const $name" + val someValue = Option(constDecl.value) astForConstOrFieldValue( stmt, name, code, - stmt.expr, + someValue, scope.addConstOrStaticInitToScope, isField = false ) - .withChild(finalModifier) - - private def astsForPropertyStmt(stmt: PhpPropertyStmt): List[Ast] = - stmt.variables.map { varDecl => - val modifierAsts = stmt.modifiers.map(newModifierNode).map(Ast(_)) - - val name = varDecl.name.name - astForConstOrFieldValue( - stmt, - name, - s"$$$name", - varDecl.defaultValue, - scope.addFieldInitToScope, - isField = true - ) - .withChildren(modifierAsts) - } + .withChildren(modifierAsts) + } + + private def astForEnumCase(stmt: PhpEnumCaseStmt): Ast = + val finalModifier = Ast(newModifierNode(ModifierTypes.FINAL)) + + val name = stmt.name.name + val code = s"case $name" + + astForConstOrFieldValue( + stmt, + name, + code, + stmt.expr, + scope.addConstOrStaticInitToScope, + isField = false + ) + .withChild(finalModifier) - private def astForConstOrFieldValue( - originNode: PhpNode, - name: String, - code: String, - value: Option[PhpExpr], - addToScope: Ast => Unit, - isField: Boolean - ): Ast = - val member = memberNode(originNode, name, code, TypeConstants.Any) - - value match - case Some(v) => - val assignAst = astForMemberAssignment(member, v, isField) - addToScope(assignAst) - case None => // Nothing to do here - Ast(member) - - private def astForCatchStmt(stmt: PhpCatchStmt): Ast = - // TODO Add variable at some point. Current implementation is consistent with C++. - stmtBodyBlockAst(stmt) - - private def astsForSwitchCase(caseStmt: PhpCaseStmt): List[Ast] = - val maybeConditionAst = caseStmt.condition.map(astForExpr) - val jumpTarget = maybeConditionAst match - case Some(conditionAst) => - NewJumpTarget().name("case").code(s"case ${conditionAst.rootCodeOrEmpty}") - case None => NewJumpTarget().name("default").code("default") - jumpTarget.lineNumber(line(caseStmt)).columnNumber(column(caseStmt)) - - val stmtAsts = caseStmt.stmts.flatMap(astsForStmt) - - Ast(jumpTarget) :: stmtAsts - - private def codeForMethodCall(call: PhpCallExpr, targetAst: Ast, name: String): String = - val callOperator = if call.isNullSafe then "?->" else "->" - s"${targetAst.rootCodeOrEmpty}$callOperator$name" - - private def codeForStaticMethodCall(call: PhpCallExpr, name: String): String = - val className = - call.target - .map(astForExpr) - .map(_.rootCode.getOrElse(UnresolvedNamespace)) - .getOrElse(UnresolvedNamespace) - s"$className::$name" - - private def astForCall(call: PhpCallExpr): Ast = - val arguments = call.args.map(astForCallArg) - - val targetAst = Option.unless(call.isStatic)(call.target.map(astForExpr)).flatten - - val nameAst = - Option.unless(call.methodName.isInstanceOf[PhpNameExpr])(astForExpr(call.methodName)) - val name = - nameAst - .map(_.rootCodeOrEmpty) - .getOrElse(call.methodName match - case nameExpr: PhpNameExpr => nameExpr.name - case other => - logger.debug( - s"Found unexpected call target type: Crash for now to handle properly later: $other" - ) - ??? - ) + private def astsForPropertyStmt(stmt: PhpPropertyStmt): List[Ast] = + stmt.variables.map { varDecl => + val modifierAsts = stmt.modifiers.map(newModifierNode).map(Ast(_)) - val argsCode = arguments - .zip(call.args.collect { case x: PhpArg => x.unpack }) - .map { - case (arg, true) => s"...${arg.rootCodeOrEmpty}" - case (arg, false) => arg.rootCodeOrEmpty - } - .mkString(",") - - val codePrefix = - if !call.isStatic && targetAst.isDefined then - codeForMethodCall(call, targetAst.get, name) - else if call.isStatic then - codeForStaticMethodCall(call, name) - else - name - - val code = s"$codePrefix($argsCode)" - - val dispatchType = - if call.isStatic || call.target.isEmpty then - DispatchTypes.STATIC_DISPATCH - else - DispatchTypes.DYNAMIC_DISPATCH - - val fullName = call.target match - // Static method call with a known class name - case Some(nameExpr: PhpNameExpr) if call.isStatic => - if nameExpr.name == "self" then composeMethodFullName(name, call.isStatic) - else s"${nameExpr.name}${StaticMethodDelimiter}$name" - - case Some(expr) => - s"$UnresolvedNamespace\\$codePrefix" - - case None if PhpBuiltins.FuncNames.contains(name) => - // No signature/namespace for MFN for builtin functions to ensure stable names as type info improves. - name - - // Function call - case None => - composeMethodFullName(name, call.isStatic) - - // Use method signature for methods that can be linked to avoid varargs issue. - val signature = s"$UnresolvedSignature(${call.args.size})" - val callRoot = callNode( - call, - code, + val name = varDecl.name.name + astForConstOrFieldValue( + stmt, name, - fullName, - dispatchType, - Some(signature), - Some(TypeConstants.Any) + s"$$$name", + varDecl.defaultValue, + scope.addFieldInitToScope, + isField = true ) + .withChildren(modifierAsts) + } + + private def astForConstOrFieldValue( + originNode: PhpNode, + name: String, + code: String, + value: Option[PhpExpr], + addToScope: Ast => Unit, + isField: Boolean + ): Ast = + val member = memberNode(originNode, name, code, TypeConstants.Any) + + value match + case Some(v) => + val assignAst = astForMemberAssignment(member, v, isField) + addToScope(assignAst) + case None => // Nothing to do here + Ast(member) + + private def astForCatchStmt(stmt: PhpCatchStmt): Ast = + // TODO Add variable at some point. Current implementation is consistent with C++. + stmtBodyBlockAst(stmt) + + private def astsForSwitchCase(caseStmt: PhpCaseStmt): List[Ast] = + val maybeConditionAst = caseStmt.condition.map(astForExpr) + val jumpTarget = maybeConditionAst match + case Some(conditionAst) => + NewJumpTarget().name("case").code(s"case ${conditionAst.rootCodeOrEmpty}") + case None => NewJumpTarget().name("default").code("default") + jumpTarget.lineNumber(line(caseStmt)).columnNumber(column(caseStmt)) + + val stmtAsts = caseStmt.stmts.flatMap(astsForStmt) + + Ast(jumpTarget) :: stmtAsts + + private def codeForMethodCall(call: PhpCallExpr, targetAst: Ast, name: String): String = + val callOperator = if call.isNullSafe then "?->" else "->" + s"${targetAst.rootCodeOrEmpty}$callOperator$name" + + private def codeForStaticMethodCall(call: PhpCallExpr, name: String): String = + val className = + call.target + .map(astForExpr) + .map(_.rootCode.getOrElse(UnresolvedNamespace)) + .getOrElse(UnresolvedNamespace) + s"$className::$name" + + private def astForCall(call: PhpCallExpr): Ast = + val arguments = call.args.map(astForCallArg) + + val targetAst = Option.unless(call.isStatic)(call.target.map(astForExpr)).flatten + + val nameAst = + Option.unless(call.methodName.isInstanceOf[PhpNameExpr])(astForExpr(call.methodName)) + val name = + nameAst + .map(_.rootCodeOrEmpty) + .getOrElse(call.methodName match + case nameExpr: PhpNameExpr => nameExpr.name + case other => + logger.debug( + s"Found unexpected call target type: Crash for now to handle properly later: $other" + ) + ??? + ) - val receiverAst = (targetAst, nameAst) match - case (Some(target), Some(n)) => - val fieldAccess = - newOperatorCallNode( - Operators.fieldAccess, - codePrefix, - line = line(call), - column = column(call) - ) - Option(callAst(fieldAccess, target :: n :: Nil)) - case (Some(target), None) => Option(target) - case (None, Some(n)) => Option(n) - case (None, None) => None - - callAst(callRoot, arguments, base = receiverAst) - end astForCall - - private def astForCallArg(arg: PhpArgument): Ast = - arg match - case PhpArg(expr, _, _, _, _) => - astForExpr(expr) - - case _: PhpVariadicPlaceholder => - val identifier = - identifierNode(arg, "...", "...", TypeConstants.VariadicPlaceholder) - Ast(identifier) - - private def astForVariableExpr(variable: PhpVariable): Ast = - // TODO Need to figure out variable variables. Maybe represent as some kind of call? - val valueAst = astForExpr(variable.value) - - valueAst.root.collect { case root: ExpressionNew => - root.code = s"$$${root.code}" + val argsCode = arguments + .zip(call.args.collect { case x: PhpArg => x.unpack }) + .map { + case (arg, true) => s"...${arg.rootCodeOrEmpty}" + case (arg, false) => arg.rootCodeOrEmpty } + .mkString(",") - valueAst.root.collect { case root: NewIdentifier => - root.lineNumber = line(variable) - root.columnNumber = column(variable) - } + val codePrefix = + if !call.isStatic && targetAst.isDefined then + codeForMethodCall(call, targetAst.get, name) + else if call.isStatic then + codeForStaticMethodCall(call, name) + else + name - valueAst + val code = s"$codePrefix($argsCode)" - private def astForNameExpr(expr: PhpNameExpr): Ast = - val identifier = identifierNode(expr, expr.name, expr.name, TypeConstants.Any) + val dispatchType = + if call.isStatic || call.target.isEmpty then + DispatchTypes.STATIC_DISPATCH + else + DispatchTypes.DYNAMIC_DISPATCH + + val fullName = call.target match + // Static method call with a known class name + case Some(nameExpr: PhpNameExpr) if call.isStatic => + if nameExpr.name == "self" then composeMethodFullName(name, call.isStatic) + else s"${nameExpr.name}${StaticMethodDelimiter}$name" + + case Some(expr) => + s"$UnresolvedNamespace\\$codePrefix" + + case None if PhpBuiltins.FuncNames.contains(name) => + // No signature/namespace for MFN for builtin functions to ensure stable names as type info improves. + name + + // Function call + case None => + composeMethodFullName(name, call.isStatic) + + // Use method signature for methods that can be linked to avoid varargs issue. + val signature = s"$UnresolvedSignature(${call.args.size})" + val callRoot = callNode( + call, + code, + name, + fullName, + dispatchType, + Some(signature), + Some(TypeConstants.Any) + ) - scope.lookupVariable(identifier.name).foreach { declaringNode => - diffGraph.addEdge(identifier, declaringNode, EdgeTypes.REF) - } + val receiverAst = (targetAst, nameAst) match + case (Some(target), Some(n)) => + val fieldAccess = + newOperatorCallNode( + Operators.fieldAccess, + codePrefix, + line = line(call), + column = column(call) + ) + Option(callAst(fieldAccess, target :: n :: Nil)) + case (Some(target), None) => Option(target) + case (None, Some(n)) => Option(n) + case (None, None) => None + + callAst(callRoot, arguments, base = receiverAst) + end astForCall + + private def astForCallArg(arg: PhpArgument): Ast = + arg match + case PhpArg(expr, _, _, _, _) => + astForExpr(expr) + + case _: PhpVariadicPlaceholder => + val identifier = + identifierNode(arg, "...", "...", TypeConstants.VariadicPlaceholder) + Ast(identifier) + + private def astForVariableExpr(variable: PhpVariable): Ast = + // TODO Need to figure out variable variables. Maybe represent as some kind of call? + val valueAst = astForExpr(variable.value) + + valueAst.root.collect { case root: ExpressionNew => + root.code = s"$$${root.code}" + } + + valueAst.root.collect { case root: NewIdentifier => + root.lineNumber = line(variable) + root.columnNumber = column(variable) + } + + valueAst + + private def astForNameExpr(expr: PhpNameExpr): Ast = + val identifier = identifierNode(expr, expr.name, expr.name, TypeConstants.Any) + + scope.lookupVariable(identifier.name).foreach { declaringNode => + diffGraph.addEdge(identifier, declaringNode, EdgeTypes.REF) + } + + Ast(identifier) + + /** This is used to rewrite the short form $xs[] = as array_push($xs, ) + * to avoid having to handle the empty array access operator as a special case in the dataflow + * engine. + * + * This representation is technically wrong in the case where the shorthand is used to initialise + * a new array (since PHP expects the first argument to array_push to be an existing array). This + * shouldn't affect dataflow, however. + */ + private def astForEmptyArrayDimAssign( + assignment: PhpAssignment, + arrayDimFetch: PhpArrayDimFetchExpr + ): Ast = + val attrs = assignment.attributes + val arrayPushArgs = List(arrayDimFetch.variable, assignment.source).map(PhpArg(_)) + val arrayPushCall = PhpCallExpr( + target = None, + methodName = PhpNameExpr("array_push", attrs), + args = arrayPushArgs, + isNullSafe = false, + isStatic = true, + attributes = attrs + ) + val arrayPushAst = astForCall(arrayPushCall) + arrayPushAst.root.collect { case astRoot: NewCall => + val args = + arrayPushAst.argEdges + .filter(_.src == astRoot) + .map(_.dst) + .collect { case arg: ExpressionNew => arg } + .sortBy(_.argumentIndex) + + if args.size != 2 then + val position = s"${line(assignment).getOrElse("")}:${filename}" + logger.debug( + s"Expected 2 call args for emptyArrayDimAssign. Not resetting code: ${position}" + ) + else + val codeOverride = s"${args.head.code}[] = ${args.last.code}" + astRoot.code(codeOverride) + } + arrayPushAst + end astForEmptyArrayDimAssign + + private def astForAssignment(assignment: PhpAssignment): Ast = + assignment.target match + case arrayDimFetch: PhpArrayDimFetchExpr if arrayDimFetch.dimension.isEmpty => + // Rewrite `$xs[] = ` as `array_push($xs, )` to simplify finding dataflows. + astForEmptyArrayDimAssign(assignment, arrayDimFetch) + + case _ => + val operatorName = assignment.assignOp + + val targetAst = astForExpr(assignment.target) + val sourceAst = astForExpr(assignment.source) + + // TODO Handle ref assigns properly (if needed). + val refSymbol = if assignment.isRefAssign then "&" else "" + val symbol = operatorSymbols.getOrElse(assignment.assignOp, assignment.assignOp) + val code = + s"${targetAst.rootCodeOrEmpty} $symbol $refSymbol${sourceAst.rootCodeOrEmpty}" + + val callNode = newOperatorCallNode( + operatorName, + code, + line = line(assignment), + column = column(assignment) + ) + callAst(callNode, List(targetAst, sourceAst)) + + private def astForEncapsed(encapsed: PhpEncapsed): Ast = + val args = encapsed.parts.map(astForExpr) + val code = args.map(_.rootCodeOrEmpty).mkString(" . ") + + args match + case singleArg :: Nil => singleArg + case _ => + val callNode = newOperatorCallNode( + PhpOperators.encaps, + code, + Some(TypeConstants.String), + line(encapsed), + column = column(encapsed) + ) + callAst(callNode, args) + + private def astForScalar(scalar: PhpScalar): Ast = + scalar match + case encapsed: PhpEncapsed => astForEncapsed(encapsed) + case simpleScalar: PhpSimpleScalar => + Ast(literalNode(scalar, simpleScalar.value, simpleScalar.typeFullName)) + case null => + logger.debug("scalar was null") + ??? + + private def astForBinOp(binOp: PhpBinaryOp): Ast = + val leftAst = astForExpr(binOp.left) + val rightAst = astForExpr(binOp.right) + + val symbol = operatorSymbols.getOrElse(binOp.operator, binOp.operator) + val code = s"${leftAst.rootCodeOrEmpty} $symbol ${rightAst.rootCodeOrEmpty}" + + val callNode = + newOperatorCallNode(binOp.operator, code, line = line(binOp), column = column(binOp)) + + callAst(callNode, List(leftAst, rightAst)) + + private def isPostfixOperator(operator: String): Boolean = + Set(Operators.postDecrement, Operators.postIncrement).contains(operator) + + private def astForUnaryOp(unaryOp: PhpUnaryOp): Ast = + val exprAst = astForExpr(unaryOp.expr) + + val symbol = operatorSymbols.getOrElse(unaryOp.operator, unaryOp.operator) + val code = + if isPostfixOperator(unaryOp.operator) then + s"${exprAst.rootCodeOrEmpty}$symbol" + else + s"$symbol${exprAst.rootCodeOrEmpty}" - Ast(identifier) - - /** This is used to rewrite the short form $xs[] = as array_push($xs, ) - * to avoid having to handle the empty array access operator as a special case in the dataflow - * engine. - * - * This representation is technically wrong in the case where the shorthand is used to - * initialise a new array (since PHP expects the first argument to array_push to be an existing - * array). This shouldn't affect dataflow, however. - */ - private def astForEmptyArrayDimAssign( - assignment: PhpAssignment, - arrayDimFetch: PhpArrayDimFetchExpr - ): Ast = - val attrs = assignment.attributes - val arrayPushArgs = List(arrayDimFetch.variable, assignment.source).map(PhpArg(_)) - val arrayPushCall = PhpCallExpr( - target = None, - methodName = PhpNameExpr("array_push", attrs), - args = arrayPushArgs, - isNullSafe = false, - isStatic = true, - attributes = attrs - ) - val arrayPushAst = astForCall(arrayPushCall) - arrayPushAst.root.collect { case astRoot: NewCall => - val args = - arrayPushAst.argEdges - .filter(_.src == astRoot) - .map(_.dst) - .collect { case arg: ExpressionNew => arg } - .sortBy(_.argumentIndex) - - if args.size != 2 then - val position = s"${line(assignment).getOrElse("")}:${filename}" - logger.debug( - s"Expected 2 call args for emptyArrayDimAssign. Not resetting code: ${position}" - ) - else - val codeOverride = s"${args.head.code}[] = ${args.last.code}" - astRoot.code(codeOverride) - } - arrayPushAst - end astForEmptyArrayDimAssign - - private def astForAssignment(assignment: PhpAssignment): Ast = - assignment.target match - case arrayDimFetch: PhpArrayDimFetchExpr if arrayDimFetch.dimension.isEmpty => - // Rewrite `$xs[] = ` as `array_push($xs, )` to simplify finding dataflows. - astForEmptyArrayDimAssign(assignment, arrayDimFetch) - - case _ => - val operatorName = assignment.assignOp - - val targetAst = astForExpr(assignment.target) - val sourceAst = astForExpr(assignment.source) - - // TODO Handle ref assigns properly (if needed). - val refSymbol = if assignment.isRefAssign then "&" else "" - val symbol = operatorSymbols.getOrElse(assignment.assignOp, assignment.assignOp) - val code = - s"${targetAst.rootCodeOrEmpty} $symbol $refSymbol${sourceAst.rootCodeOrEmpty}" - - val callNode = newOperatorCallNode( - operatorName, - code, - line = line(assignment), - column = column(assignment) - ) - callAst(callNode, List(targetAst, sourceAst)) - - private def astForEncapsed(encapsed: PhpEncapsed): Ast = - val args = encapsed.parts.map(astForExpr) - val code = args.map(_.rootCodeOrEmpty).mkString(" . ") - - args match - case singleArg :: Nil => singleArg - case _ => - val callNode = newOperatorCallNode( - PhpOperators.encaps, - code, - Some(TypeConstants.String), - line(encapsed), - column = column(encapsed) - ) - callAst(callNode, args) + val callNode = newOperatorCallNode( + unaryOp.operator, + code, + line = line(unaryOp), + column = column(unaryOp) + ) - private def astForScalar(scalar: PhpScalar): Ast = - scalar match - case encapsed: PhpEncapsed => astForEncapsed(encapsed) - case simpleScalar: PhpSimpleScalar => - Ast(literalNode(scalar, simpleScalar.value, simpleScalar.typeFullName)) - case null => - logger.debug("scalar was null") - ??? + callAst(callNode, exprAst :: Nil) - private def astForBinOp(binOp: PhpBinaryOp): Ast = - val leftAst = astForExpr(binOp.left) - val rightAst = astForExpr(binOp.right) + private def astForCastExpr(castExpr: PhpCast): Ast = + val typeFullName = castExpr.typ + val typ = typeRefNode(castExpr, typeFullName, typeFullName) - val symbol = operatorSymbols.getOrElse(binOp.operator, binOp.operator) - val code = s"${leftAst.rootCodeOrEmpty} $symbol ${rightAst.rootCodeOrEmpty}" + val expr = astForExpr(castExpr.expr) + val codeStr = s"($typeFullName) ${expr.rootCodeOrEmpty}" - val callNode = - newOperatorCallNode(binOp.operator, code, line = line(binOp), column = column(binOp)) + val callNode = + newOperatorCallNode( + name = Operators.cast, + codeStr, + Some(typeFullName), + line(castExpr), + column = column(castExpr) + ) - callAst(callNode, List(leftAst, rightAst)) + callAst(callNode, Ast(typ) :: expr :: Nil) - private def isPostfixOperator(operator: String): Boolean = - Set(Operators.postDecrement, Operators.postIncrement).contains(operator) + private def astForIsSetExpr(isSetExpr: PhpIsset): Ast = + val name = PhpOperators.issetFunc + val args = isSetExpr.vars.map(astForExpr) + val code = s"$name(${args.map(_.rootCodeOrEmpty).mkString(",")})" - private def astForUnaryOp(unaryOp: PhpUnaryOp): Ast = - val exprAst = astForExpr(unaryOp.expr) + val callNode = + newOperatorCallNode( + name, + code, + typeFullName = Some(TypeConstants.Bool), + line = line(isSetExpr), + column = column(isSetExpr) + ) + .methodFullName(PhpOperators.issetFunc) - val symbol = operatorSymbols.getOrElse(unaryOp.operator, unaryOp.operator) - val code = - if isPostfixOperator(unaryOp.operator) then - s"${exprAst.rootCodeOrEmpty}$symbol" - else - s"$symbol${exprAst.rootCodeOrEmpty}" + callAst(callNode, args) + private def astForPrintExpr(printExpr: PhpPrint): Ast = + val name = PhpOperators.printFunc + val arg = astForExpr(printExpr.expr) + val code = s"$name(${arg.rootCodeOrEmpty})" - val callNode = newOperatorCallNode( - unaryOp.operator, + val callNode = + newOperatorCallNode( + name, code, - line = line(unaryOp), - column = column(unaryOp) + typeFullName = Some(TypeConstants.Int), + line = line(printExpr), + column = column(printExpr) ) + .methodFullName(PhpOperators.printFunc) + + callAst(callNode, arg :: Nil) + + private def astForTernaryOp(ternaryOp: PhpTernaryOp): Ast = + val conditionAst = astForExpr(ternaryOp.condition) + val maybeThenAst = ternaryOp.thenExpr.map(astForExpr) + val elseAst = astForExpr(ternaryOp.elseExpr) + + val operatorName = + if maybeThenAst.isDefined then Operators.conditional else PhpOperators.elvisOp + val code = maybeThenAst match + case Some(thenAst) => + s"${conditionAst.rootCodeOrEmpty} ? ${thenAst.rootCodeOrEmpty} : ${elseAst.rootCodeOrEmpty}" + case None => s"${conditionAst.rootCodeOrEmpty} ?: ${elseAst.rootCodeOrEmpty}" + + val callNode = newOperatorCallNode( + operatorName, + code, + line = line(ternaryOp), + column = column(ternaryOp) + ) - callAst(callNode, exprAst :: Nil) + val args = List(Option(conditionAst), maybeThenAst, Option(elseAst)).flatten + callAst(callNode, args) + end astForTernaryOp - private def astForCastExpr(castExpr: PhpCast): Ast = - val typeFullName = castExpr.typ - val typ = typeRefNode(castExpr, typeFullName, typeFullName) + private def astForThrow(expr: PhpThrowExpr): Ast = + val thrownExpr = astForExpr(expr.expr) + val code = s"throw ${thrownExpr.rootCodeOrEmpty}" - val expr = astForExpr(castExpr.expr) - val codeStr = s"($typeFullName) ${expr.rootCodeOrEmpty}" + val throwNode = controlStructureNode(expr, ControlStructureTypes.THROW, code) - val callNode = - newOperatorCallNode( - name = Operators.cast, - codeStr, - Some(typeFullName), - line(castExpr), - column = column(castExpr) - ) + Ast(throwNode).withChild(thrownExpr) - callAst(callNode, Ast(typ) :: expr :: Nil) + private def astForClone(expr: PhpCloneExpr): Ast = + val name = PhpOperators.cloneFunc + val argAst = astForExpr(expr.expr) + val argType = argAst.rootType.orElse(Some(TypeConstants.Any)) + val code = s"$name ${argAst.rootCodeOrEmpty}" - private def astForIsSetExpr(isSetExpr: PhpIsset): Ast = - val name = PhpOperators.issetFunc - val args = isSetExpr.vars.map(astForExpr) - val code = s"$name(${args.map(_.rootCodeOrEmpty).mkString(",")})" + val callNode = newOperatorCallNode(name, code, argType, line(expr), column = column(expr)) + .methodFullName(PhpOperators.cloneFunc) - val callNode = - newOperatorCallNode( - name, - code, - typeFullName = Some(TypeConstants.Bool), - line = line(isSetExpr), - column = column(isSetExpr) - ) - .methodFullName(PhpOperators.issetFunc) + callAst(callNode, argAst :: Nil) - callAst(callNode, args) - private def astForPrintExpr(printExpr: PhpPrint): Ast = - val name = PhpOperators.printFunc - val arg = astForExpr(printExpr.expr) - val code = s"$name(${arg.rootCodeOrEmpty})" + private def astForEmpty(expr: PhpEmptyExpr): Ast = + val name = PhpOperators.emptyFunc + val argAst = astForExpr(expr.expr) + val code = s"$name(${argAst.rootCodeOrEmpty})" - val callNode = - newOperatorCallNode( - name, - code, - typeFullName = Some(TypeConstants.Int), - line = line(printExpr), - column = column(printExpr) - ) - .methodFullName(PhpOperators.printFunc) - - callAst(callNode, arg :: Nil) + val callNode = + newOperatorCallNode( + name, + code, + typeFullName = Some(TypeConstants.Bool), + line = line(expr), + column = column(expr) + ) + .methodFullName(PhpOperators.emptyFunc) - private def astForTernaryOp(ternaryOp: PhpTernaryOp): Ast = - val conditionAst = astForExpr(ternaryOp.condition) - val maybeThenAst = ternaryOp.thenExpr.map(astForExpr) - val elseAst = astForExpr(ternaryOp.elseExpr) + callAst(callNode, argAst :: Nil) - val operatorName = - if maybeThenAst.isDefined then Operators.conditional else PhpOperators.elvisOp - val code = maybeThenAst match - case Some(thenAst) => - s"${conditionAst.rootCodeOrEmpty} ? ${thenAst.rootCodeOrEmpty} : ${elseAst.rootCodeOrEmpty}" - case None => s"${conditionAst.rootCodeOrEmpty} ?: ${elseAst.rootCodeOrEmpty}" + private def astForEval(expr: PhpEvalExpr): Ast = + val name = PhpOperators.evalFunc + val argAst = astForExpr(expr.expr) + val code = s"$name(${argAst.rootCodeOrEmpty})" - val callNode = newOperatorCallNode( - operatorName, + val callNode = + newOperatorCallNode( + name, code, - line = line(ternaryOp), - column = column(ternaryOp) + typeFullName = Some(TypeConstants.Bool), + line = line(expr), + column = column(expr) ) + .methodFullName(PhpOperators.evalFunc) - val args = List(Option(conditionAst), maybeThenAst, Option(elseAst)).flatten - callAst(callNode, args) - end astForTernaryOp - - private def astForThrow(expr: PhpThrowExpr): Ast = - val thrownExpr = astForExpr(expr.expr) - val code = s"throw ${thrownExpr.rootCodeOrEmpty}" + callAst(callNode, argAst :: Nil) - val throwNode = controlStructureNode(expr, ControlStructureTypes.THROW, code) + private def astForExit(expr: PhpExitExpr): Ast = + val name = PhpOperators.exitFunc + val args = expr.expr.map(astForExpr) + val code = s"$name(${args.map(_.rootCodeOrEmpty).getOrElse("")})" - Ast(throwNode).withChild(thrownExpr) + val callNode = newOperatorCallNode( + name, + code, + Some(TypeConstants.Void), + line(expr), + column = column(expr) + ) + .methodFullName(PhpOperators.exitFunc) - private def astForClone(expr: PhpCloneExpr): Ast = - val name = PhpOperators.cloneFunc - val argAst = astForExpr(expr.expr) - val argType = argAst.rootType.orElse(Some(TypeConstants.Any)) - val code = s"$name ${argAst.rootCodeOrEmpty}" + callAst(callNode, args.toList) - val callNode = newOperatorCallNode(name, code, argType, line(expr), column = column(expr)) - .methodFullName(PhpOperators.cloneFunc) + private def getTmpIdentifier( + originNode: PhpNode, + maybeTypeFullName: Option[String], + prefix: String = "" + ): NewIdentifier = + val name = s"$prefix${getNewTmpName()}" + val typeFullName = maybeTypeFullName.getOrElse(TypeConstants.Any) + identifierNode(originNode, name, s"$$$name", typeFullName) - callAst(callNode, argAst :: Nil) + private def astForArrayExpr(expr: PhpArrayExpr): Ast = + val idxTracker = new ArrayIndexTracker - private def astForEmpty(expr: PhpEmptyExpr): Ast = - val name = PhpOperators.emptyFunc - val argAst = astForExpr(expr.expr) - val code = s"$name(${argAst.rootCodeOrEmpty})" + val tmpIdentifier = getTmpIdentifier(expr, Some(TypeConstants.Array)) - val callNode = - newOperatorCallNode( - name, - code, - typeFullName = Some(TypeConstants.Bool), - line = line(expr), - column = column(expr) + val itemAssignments = expr.items.flatMap { + case Some(item) => Option(assignForArrayItem(item, tmpIdentifier.name, idxTracker)) + case None => + idxTracker.next // Skip an index + None + } + val arrayBlock = blockNode(expr) + + Ast(arrayBlock) + .withChildren(itemAssignments) + .withChild(Ast(tmpIdentifier)) + + private def astForListExpr(expr: PhpListExpr): Ast = + /* TODO: Handling list in a way that will actually work with dataflow tracking is somewhat more complicated than + * this and will likely need a fairly ugly lowering. + * + * In short, the case: + * list($a, $b) = $arr; + * can be lowered to: + * $a = $arr[0]; + * $b = $arr[1]; + * + * the case: + * list("id" => $a, "name" => $b) = $arr; + * can be lowered to: + * $a = $arr["id"]; + * $b = $arr["name"]; + * + * and the case: + * foreach ($arr as list($a, $b)) { ... } + * can be lowered as above for each $arr[i]; + * + * The below is just a placeholder to prevent crashes while figuring out the cleanest way to + * implement the above lowering or to think of a better way to do it. + */ + + val name = PhpOperators.listFunc + val args = expr.items.flatten.map { item => astForExpr(item.value) } + val listCode = s"$name(${args.map(_.rootCodeOrEmpty).mkString(",")})" + val listNode = newOperatorCallNode(name, listCode, line = line(expr), column = column(expr)) + .methodFullName(PhpOperators.listFunc) + + callAst(listNode, args) + end astForListExpr + + private def astForNewExpr(expr: PhpNewExpr): Ast = + expr.className match + case classLikeStmt: PhpClassLikeStmt => + astForAnonymousClassInstantiation(expr, classLikeStmt) + + case classNameExpr: PhpExpr => + astForSimpleNewExpr(expr, classNameExpr) + + case other => + throw new NotImplementedError( + s"unexpected expression '$other' of type ${other.getClass}" ) - .methodFullName(PhpOperators.emptyFunc) - - callAst(callNode, argAst :: Nil) - private def astForEval(expr: PhpEvalExpr): Ast = - val name = PhpOperators.evalFunc - val argAst = astForExpr(expr.expr) - val code = s"$name(${argAst.rootCodeOrEmpty})" + private def astForMatchExpr(expr: PhpMatchExpr): Ast = + val conditionAst = astForExpr(expr.condition) - val callNode = - newOperatorCallNode( - name, - code, - typeFullName = Some(TypeConstants.Bool), - line = line(expr), - column = column(expr) - ) - .methodFullName(PhpOperators.evalFunc) + val matchNode = controlStructureNode( + expr, + ControlStructureTypes.MATCH, + s"match (${conditionAst.rootCodeOrEmpty})" + ) - callAst(callNode, argAst :: Nil) + val matchBodyBlock = blockNode(expr) + val armsAsts = expr.matchArms.flatMap(astsForMatchArm) + val matchBody = Ast(matchBodyBlock).withChildren(armsAsts) + + controlStructureAst(matchNode, Option(conditionAst), matchBody :: Nil) + + private def astsForMatchArm(matchArm: PhpMatchArm): List[Ast] = + // TODO Don't just throw away the condition asts here (also for switch cases) + val targets = matchArm.conditions.map { condition => + val conditionAst = astForExpr(condition) + // In PHP cases aren't labeled with `case`, but this is used by the CFG creator to differentiate between + // case/default labels and other labels. + val code = s"case ${conditionAst.rootCode.getOrElse(NameConstants.Unknown)}" + NewJumpTarget().name(code).code(code).lineNumber(line(condition)).columnNumber(column( + condition + )) + } + val defaultLabel = Option.when(matchArm.isDefault)( + NewJumpTarget().name(NameConstants.Default).code(NameConstants.Default).lineNumber(line( + matchArm + )).columnNumber(column(matchArm)) + ) + val targetAsts = (targets ++ defaultLabel.toList).map(Ast(_)) - private def astForExit(expr: PhpExitExpr): Ast = - val name = PhpOperators.exitFunc - val args = expr.expr.map(astForExpr) - val code = s"$name(${args.map(_.rootCodeOrEmpty).getOrElse("")})" + val bodyAst = astForExpr(matchArm.body) - val callNode = newOperatorCallNode( - name, - code, - Some(TypeConstants.Void), - line(expr), - column = column(expr) - ) - .methodFullName(PhpOperators.exitFunc) + targetAsts :+ bodyAst + end astsForMatchArm - callAst(callNode, args.toList) + private def astForYieldExpr(expr: PhpYieldExpr): Ast = + val maybeKey = expr.key.map(astForExpr) + val maybeVal = expr.value.map(astForExpr) - private def getTmpIdentifier( - originNode: PhpNode, - maybeTypeFullName: Option[String], - prefix: String = "" - ): NewIdentifier = - val name = s"$prefix${getNewTmpName()}" - val typeFullName = maybeTypeFullName.getOrElse(TypeConstants.Any) - identifierNode(originNode, name, s"$$$name", typeFullName) + val code = (maybeKey, maybeVal) match + case (Some(key), Some(value)) => + s"yield ${key.rootCodeOrEmpty} => ${value.rootCodeOrEmpty}" - private def astForArrayExpr(expr: PhpArrayExpr): Ast = - val idxTracker = new ArrayIndexTracker + case _ => + s"yield ${maybeKey.map(_.rootCodeOrEmpty) + .getOrElse("")}${maybeVal.map(_.rootCodeOrEmpty).getOrElse("")}".trim - val tmpIdentifier = getTmpIdentifier(expr, Some(TypeConstants.Array)) + val yieldNode = controlStructureNode(expr, ControlStructureTypes.YIELD, code) - val itemAssignments = expr.items.flatMap { - case Some(item) => Option(assignForArrayItem(item, tmpIdentifier.name, idxTracker)) - case None => - idxTracker.next // Skip an index - None - } - val arrayBlock = blockNode(expr) - - Ast(arrayBlock) - .withChildren(itemAssignments) - .withChild(Ast(tmpIdentifier)) - - private def astForListExpr(expr: PhpListExpr): Ast = - /* TODO: Handling list in a way that will actually work with dataflow tracking is somewhat more complicated than - * this and will likely need a fairly ugly lowering. - * - * In short, the case: - * list($a, $b) = $arr; - * can be lowered to: - * $a = $arr[0]; - * $b = $arr[1]; - * - * the case: - * list("id" => $a, "name" => $b) = $arr; - * can be lowered to: - * $a = $arr["id"]; - * $b = $arr["name"]; - * - * and the case: - * foreach ($arr as list($a, $b)) { ... } - * can be lowered as above for each $arr[i]; - * - * The below is just a placeholder to prevent crashes while figuring out the cleanest way to - * implement the above lowering or to think of a better way to do it. - */ - - val name = PhpOperators.listFunc - val args = expr.items.flatten.map { item => astForExpr(item.value) } - val listCode = s"$name(${args.map(_.rootCodeOrEmpty).mkString(",")})" - val listNode = newOperatorCallNode(name, listCode, line = line(expr), column = column(expr)) - .methodFullName(PhpOperators.listFunc) - - callAst(listNode, args) - end astForListExpr - - private def astForNewExpr(expr: PhpNewExpr): Ast = - expr.className match - case classLikeStmt: PhpClassLikeStmt => - astForAnonymousClassInstantiation(expr, classLikeStmt) - - case classNameExpr: PhpExpr => - astForSimpleNewExpr(expr, classNameExpr) - - case other => - throw new NotImplementedError( - s"unexpected expression '$other' of type ${other.getClass}" - ) + Ast(yieldNode) + .withChildren(maybeKey.toList) + .withChildren(maybeVal.toList) - private def astForMatchExpr(expr: PhpMatchExpr): Ast = - val conditionAst = astForExpr(expr.condition) + private def astForClosureExpr(closureExpr: PhpClosureExpr): Ast = + val methodName = scope.getScopedClosureName + val methodRef = methodRefNode(closureExpr, methodName, methodName, TypeConstants.Any) - val matchNode = controlStructureNode( - expr, - ControlStructureTypes.MATCH, - s"match (${conditionAst.rootCodeOrEmpty})" - ) + val localsForUses = closureExpr.uses.flatMap { closureUse => + val variableAst = astForExpr(closureUse.variable) + val codePref = if closureUse.byRef then "&" else "" - val matchBodyBlock = blockNode(expr) - val armsAsts = expr.matchArms.flatMap(astsForMatchArm) - val matchBody = Ast(matchBodyBlock).withChildren(armsAsts) - - controlStructureAst(matchNode, Option(conditionAst), matchBody :: Nil) - - private def astsForMatchArm(matchArm: PhpMatchArm): List[Ast] = - // TODO Don't just throw away the condition asts here (also for switch cases) - val targets = matchArm.conditions.map { condition => - val conditionAst = astForExpr(condition) - // In PHP cases aren't labeled with `case`, but this is used by the CFG creator to differentiate between - // case/default labels and other labels. - val code = s"case ${conditionAst.rootCode.getOrElse(NameConstants.Unknown)}" - NewJumpTarget().name(code).code(code).lineNumber(line(condition)).columnNumber(column( - condition + variableAst.root match + case Some(identifier: NewIdentifier) => + // This is the expected case and is handled well + Some(localNode( + closureExpr, + identifier.name, + codePref ++ identifier.code, + TypeConstants.Any )) - } - val defaultLabel = Option.when(matchArm.isDefault)( - NewJumpTarget().name(NameConstants.Default).code(NameConstants.Default).lineNumber(line( - matchArm - )).columnNumber(column(matchArm)) - ) - val targetAsts = (targets ++ defaultLabel.toList).map(Ast(_)) - - val bodyAst = astForExpr(matchArm.body) - - targetAsts :+ bodyAst - end astsForMatchArm - - private def astForYieldExpr(expr: PhpYieldExpr): Ast = - val maybeKey = expr.key.map(astForExpr) - val maybeVal = expr.value.map(astForExpr) - - val code = (maybeKey, maybeVal) match - case (Some(key), Some(value)) => - s"yield ${key.rootCodeOrEmpty} => ${value.rootCodeOrEmpty}" - - case _ => - s"yield ${maybeKey.map(_.rootCodeOrEmpty).getOrElse("")}${maybeVal.map(_.rootCodeOrEmpty).getOrElse("")}".trim - - val yieldNode = controlStructureNode(expr, ControlStructureTypes.YIELD, code) - - Ast(yieldNode) - .withChildren(maybeKey.toList) - .withChildren(maybeVal.toList) - - private def astForClosureExpr(closureExpr: PhpClosureExpr): Ast = - val methodName = scope.getScopedClosureName - val methodRef = methodRefNode(closureExpr, methodName, methodName, TypeConstants.Any) - - val localsForUses = closureExpr.uses.flatMap { closureUse => - val variableAst = astForExpr(closureUse.variable) - val codePref = if closureUse.byRef then "&" else "" - - variableAst.root match - case Some(identifier: NewIdentifier) => - // This is the expected case and is handled well - Some(localNode( - closureExpr, - identifier.name, - codePref ++ identifier.code, - TypeConstants.Any - )) - case Some(expr: ExpressionNew) => - // Results here may be bad, but its' the best we're likely to do - Some(localNode( - closureExpr, - expr.code, - codePref ++ expr.code, - TypeConstants.Any - )) - case Some(other) => - // This should never happen - logger.debug(s"Found ast '$other' for closure use in $filename") - None - case None => - // This should never happen - logger.debug(s"Found empty ast for closure use in $filename") - None - end match - } - - // Add closure bindings to diffgraph - localsForUses.foreach { local => - val closureBindingId = s"$filename:$methodName:${local.name}" - local.closureBindingId(closureBindingId) - scope.addToScope(local.name, local) - - val closureBindingNode = NewClosureBinding() - .closureBindingId(closureBindingId) - .closureOriginalName(local.name) - .evaluationStrategy(EvaluationStrategies.BY_SHARING) - - // The ref edge to the captured local is added in the ClosureRefPass - diffGraph.addNode(closureBindingNode) - diffGraph.addEdge(methodRef, closureBindingNode, EdgeTypes.CAPTURE) - } - - // Create method for closure - val name = PhpNameExpr(methodName, closureExpr.attributes) - // TODO Check for static modifier - val modifiers = - "LAMBDA" :: (if closureExpr.isStatic then ModifierTypes.STATIC :: Nil else Nil) - val methodDecl = PhpMethodDecl( - name, - closureExpr.params, - modifiers, - closureExpr.returnType, - closureExpr.stmts, - closureExpr.returnByRef, - namespacedName = None, - isClassMethod = closureExpr.isStatic, - closureExpr.attributes - ) - val methodAst = astForMethodDecl(methodDecl, localsForUses.map(Ast(_)), Option(methodName)) - - val usesCode = localsForUses match - case Nil => "" - case locals => s" use(${locals.map(_.code).mkString(", ")})" - methodAst.root.collect { case method: NewMethod => method }.foreach { methodNode => - methodNode.code(methodNode.code ++ usesCode) - } - - // Add method to scope to be attached to typeDecl later - scope.addAnonymousMethod(methodAst) - - Ast(methodRef) - end astForClosureExpr - - private def astForYieldFromExpr(expr: PhpYieldFromExpr): Ast = - // TODO This is currently only distinguishable from yield by the code field. Decide whether to treat YIELD_FROM - // separately or whether to lower this to a foreach with regular yields. - val exprAst = astForExpr(expr.expr) - - val code = s"yield from ${exprAst.rootCodeOrEmpty}" - - val yieldNode = controlStructureNode(expr, ControlStructureTypes.YIELD, code) - - Ast(yieldNode) - .withChild(exprAst) - - private def astForAnonymousClassInstantiation( - expr: PhpNewExpr, - classLikeStmt: PhpClassLikeStmt - ): Ast = - // TODO Do this along with other anonymous class support - Ast() - - private def astForSimpleNewExpr(expr: PhpNewExpr, classNameExpr: PhpExpr): Ast = - val (maybeNameAst, className) = classNameExpr match - case nameExpr: PhpNameExpr => - (None, nameExpr.name) - - case expr: PhpExpr => - val ast = astForExpr(expr) - // The name doesn't make sense in this case, but the AST will be more useful - val name = ast.rootCode.getOrElse(NameConstants.Unknown) - (Option(ast), name) - - val tmpIdentifier = getTmpIdentifier(expr, Option(className)) - - // Alloc assign - val allocCode = s"$className.()" - val allocNode = - newOperatorCallNode( - Operators.alloc, - allocCode, - Option(className), - line(expr), - column = column(expr) - ) - val allocAst = callAst(allocNode, base = maybeNameAst) - val allocAssignCode = s"${tmpIdentifier.code} = ${allocAst.rootCodeOrEmpty}" - val allocAssignNode = newOperatorCallNode( - Operators.assignment, - allocAssignCode, + case Some(expr: ExpressionNew) => + // Results here may be bad, but its' the best we're likely to do + Some(localNode( + closureExpr, + expr.code, + codePref ++ expr.code, + TypeConstants.Any + )) + case Some(other) => + // This should never happen + logger.debug(s"Found ast '$other' for closure use in $filename") + None + case None => + // This should never happen + logger.debug(s"Found empty ast for closure use in $filename") + None + end match + } + + // Add closure bindings to diffgraph + localsForUses.foreach { local => + val closureBindingId = s"$filename:$methodName:${local.name}" + local.closureBindingId(closureBindingId) + scope.addToScope(local.name, local) + + val closureBindingNode = NewClosureBinding() + .closureBindingId(closureBindingId) + .closureOriginalName(local.name) + .evaluationStrategy(EvaluationStrategies.BY_SHARING) + + // The ref edge to the captured local is added in the ClosureRefPass + diffGraph.addNode(closureBindingNode) + diffGraph.addEdge(methodRef, closureBindingNode, EdgeTypes.CAPTURE) + } + + // Create method for closure + val name = PhpNameExpr(methodName, closureExpr.attributes) + // TODO Check for static modifier + val modifiers = + "LAMBDA" :: (if closureExpr.isStatic then ModifierTypes.STATIC :: Nil else Nil) + val methodDecl = PhpMethodDecl( + name, + closureExpr.params, + modifiers, + closureExpr.returnType, + closureExpr.stmts, + closureExpr.returnByRef, + namespacedName = None, + isClassMethod = closureExpr.isStatic, + closureExpr.attributes + ) + val methodAst = astForMethodDecl(methodDecl, localsForUses.map(Ast(_)), Option(methodName)) + + val usesCode = localsForUses match + case Nil => "" + case locals => s" use(${locals.map(_.code).mkString(", ")})" + methodAst.root.collect { case method: NewMethod => method }.foreach { methodNode => + methodNode.code(methodNode.code ++ usesCode) + } + + // Add method to scope to be attached to typeDecl later + scope.addAnonymousMethod(methodAst) + + Ast(methodRef) + end astForClosureExpr + + private def astForYieldFromExpr(expr: PhpYieldFromExpr): Ast = + // TODO This is currently only distinguishable from yield by the code field. Decide whether to treat YIELD_FROM + // separately or whether to lower this to a foreach with regular yields. + val exprAst = astForExpr(expr.expr) + + val code = s"yield from ${exprAst.rootCodeOrEmpty}" + + val yieldNode = controlStructureNode(expr, ControlStructureTypes.YIELD, code) + + Ast(yieldNode) + .withChild(exprAst) + + private def astForAnonymousClassInstantiation( + expr: PhpNewExpr, + classLikeStmt: PhpClassLikeStmt + ): Ast = + // TODO Do this along with other anonymous class support + Ast() + + private def astForSimpleNewExpr(expr: PhpNewExpr, classNameExpr: PhpExpr): Ast = + val (maybeNameAst, className) = classNameExpr match + case nameExpr: PhpNameExpr => + (None, nameExpr.name) + + case expr: PhpExpr => + val ast = astForExpr(expr) + // The name doesn't make sense in this case, but the AST will be more useful + val name = ast.rootCode.getOrElse(NameConstants.Unknown) + (Option(ast), name) + + val tmpIdentifier = getTmpIdentifier(expr, Option(className)) + + // Alloc assign + val allocCode = s"$className.()" + val allocNode = + newOperatorCallNode( + Operators.alloc, + allocCode, Option(className), line(expr), column = column(expr) ) - val allocAssignAst = callAst(allocAssignNode, Ast(tmpIdentifier) :: allocAst :: Nil) - - // Init node - val initArgs = expr.args.map(astForCallArg) - val initSignature = s"$UnresolvedSignature(${initArgs.size})" - val initFullName = s"$className$InstanceMethodDelimiter${ConstructorMethodName}" - val initCode = s"$initFullName(${initArgs.map(_.rootCodeOrEmpty).mkString(",")})" - val initCallNode = callNode( - expr, - initCode, - ConstructorMethodName, - initFullName, - DispatchTypes.DYNAMIC_DISPATCH, - Some(initSignature), - Some(TypeConstants.Any) - ) - val initReceiver = Ast(tmpIdentifier.copy) - val initCallAst = callAst(initCallNode, initArgs, base = Option(initReceiver)) - - // Return identifier - val returnIdentifierAst = Ast(tmpIdentifier.copy) - - Ast(blockNode(expr, "", TypeConstants.Any)) - .withChild(allocAssignAst) - .withChild(initCallAst) - .withChild(returnIdentifierAst) - end astForSimpleNewExpr - - private def dimensionFromSimpleScalar( - scalar: PhpSimpleScalar, - idxTracker: ArrayIndexTracker - ): PhpExpr = - val maybeIntValue = scalar match - case string: PhpString => - string.value - .drop(1) - .dropRight(1) - .toIntOption - - case number => number.value.toIntOption - - maybeIntValue match - case Some(intValue) => - idxTracker.updateValue(intValue) - PhpInt(intValue.toString, scalar.attributes) - - case None => - scalar - end dimensionFromSimpleScalar - private def assignForArrayItem( - item: PhpArrayItem, - name: String, - idxTracker: ArrayIndexTracker - ): Ast = - // It's perhaps a bit clumsy to reconstruct PhpExpr nodes here, but reuse astForArrayDimExpr for consistency - val variable = PhpVariable(PhpNameExpr(name, item.attributes), item.attributes) - - val dimension = item.key match - case Some(key: PhpSimpleScalar) => dimensionFromSimpleScalar(key, idxTracker) - case Some(key) => key - case None => PhpInt(idxTracker.next, item.attributes) - - val dimFetchNode = PhpArrayDimFetchExpr(variable, Option(dimension), item.attributes) - val dimFetchAst = astForArrayDimFetchExpr(dimFetchNode) - - val valueAst = astForArrayItemValue(item) - - val assignCode = s"${dimFetchAst.rootCodeOrEmpty} = ${valueAst.rootCodeOrEmpty}" - - val assignNode = newOperatorCallNode( - Operators.assignment, - assignCode, - line = line(item), - column = column(item) - ) - - callAst(assignNode, dimFetchAst :: valueAst :: Nil) - end assignForArrayItem - - private def astForArrayItemValue(item: PhpArrayItem): Ast = - val exprAst = astForExpr(item.value) - val valueCode = exprAst.rootCodeOrEmpty + val allocAst = callAst(allocNode, base = maybeNameAst) + val allocAssignCode = s"${tmpIdentifier.code} = ${allocAst.rootCodeOrEmpty}" + val allocAssignNode = newOperatorCallNode( + Operators.assignment, + allocAssignCode, + Option(className), + line(expr), + column = column(expr) + ) + val allocAssignAst = callAst(allocAssignNode, Ast(tmpIdentifier) :: allocAst :: Nil) + + // Init node + val initArgs = expr.args.map(astForCallArg) + val initSignature = s"$UnresolvedSignature(${initArgs.size})" + val initFullName = s"$className$InstanceMethodDelimiter${ConstructorMethodName}" + val initCode = s"$initFullName(${initArgs.map(_.rootCodeOrEmpty).mkString(",")})" + val initCallNode = callNode( + expr, + initCode, + ConstructorMethodName, + initFullName, + DispatchTypes.DYNAMIC_DISPATCH, + Some(initSignature), + Some(TypeConstants.Any) + ) + val initReceiver = Ast(tmpIdentifier.copy) + val initCallAst = callAst(initCallNode, initArgs, base = Option(initReceiver)) + + // Return identifier + val returnIdentifierAst = Ast(tmpIdentifier.copy) + + Ast(blockNode(expr, "", TypeConstants.Any)) + .withChild(allocAssignAst) + .withChild(initCallAst) + .withChild(returnIdentifierAst) + end astForSimpleNewExpr + + private def dimensionFromSimpleScalar( + scalar: PhpSimpleScalar, + idxTracker: ArrayIndexTracker + ): PhpExpr = + val maybeIntValue = scalar match + case string: PhpString => + string.value + .drop(1) + .dropRight(1) + .toIntOption + + case number => number.value.toIntOption + + maybeIntValue match + case Some(intValue) => + idxTracker.updateValue(intValue) + PhpInt(intValue.toString, scalar.attributes) + + case None => + scalar + end dimensionFromSimpleScalar + private def assignForArrayItem( + item: PhpArrayItem, + name: String, + idxTracker: ArrayIndexTracker + ): Ast = + // It's perhaps a bit clumsy to reconstruct PhpExpr nodes here, but reuse astForArrayDimExpr for consistency + val variable = PhpVariable(PhpNameExpr(name, item.attributes), item.attributes) + + val dimension = item.key match + case Some(key: PhpSimpleScalar) => dimensionFromSimpleScalar(key, idxTracker) + case Some(key) => key + case None => PhpInt(idxTracker.next, item.attributes) + + val dimFetchNode = PhpArrayDimFetchExpr(variable, Option(dimension), item.attributes) + val dimFetchAst = astForArrayDimFetchExpr(dimFetchNode) + + val valueAst = astForArrayItemValue(item) + + val assignCode = s"${dimFetchAst.rootCodeOrEmpty} = ${valueAst.rootCodeOrEmpty}" + + val assignNode = newOperatorCallNode( + Operators.assignment, + assignCode, + line = line(item), + column = column(item) + ) - if item.byRef then - val parentCall = - newOperatorCallNode( - Operators.addressOf, - s"&$valueCode", - line = line(item), - column = column(item) - ) - callAst(parentCall, exprAst :: Nil) - else if item.unpack then - val parentCall = - newOperatorCallNode( - PhpOperators.unpack, - s"...$valueCode", - line = line(item), - column = column(item) - ) - callAst(parentCall, exprAst :: Nil) - else - exprAst - end if - end astForArrayItemValue - - private def astForArrayDimFetchExpr(expr: PhpArrayDimFetchExpr): Ast = - val variableAst = astForExpr(expr.variable) - val variableCode = variableAst.rootCodeOrEmpty - - expr.dimension match - case Some(dimension) => - val dimensionAst = astForExpr(dimension) - val code = s"$variableCode[${dimensionAst.rootCodeOrEmpty}]" - val accessNode = newOperatorCallNode( - Operators.indexAccess, - code, - line = line(expr), - column = column(expr) - ) - callAst(accessNode, variableAst :: dimensionAst :: Nil) + callAst(assignNode, dimFetchAst :: valueAst :: Nil) + end assignForArrayItem + + private def astForArrayItemValue(item: PhpArrayItem): Ast = + val exprAst = astForExpr(item.value) + val valueCode = exprAst.rootCodeOrEmpty + + if item.byRef then + val parentCall = + newOperatorCallNode( + Operators.addressOf, + s"&$valueCode", + line = line(item), + column = column(item) + ) + callAst(parentCall, exprAst :: Nil) + else if item.unpack then + val parentCall = + newOperatorCallNode( + PhpOperators.unpack, + s"...$valueCode", + line = line(item), + column = column(item) + ) + callAst(parentCall, exprAst :: Nil) + else + exprAst + end if + end astForArrayItemValue + + private def astForArrayDimFetchExpr(expr: PhpArrayDimFetchExpr): Ast = + val variableAst = astForExpr(expr.variable) + val variableCode = variableAst.rootCodeOrEmpty + + expr.dimension match + case Some(dimension) => + val dimensionAst = astForExpr(dimension) + val code = s"$variableCode[${dimensionAst.rootCodeOrEmpty}]" + val accessNode = newOperatorCallNode( + Operators.indexAccess, + code, + line = line(expr), + column = column(expr) + ) + callAst(accessNode, variableAst :: dimensionAst :: Nil) + + case None => + val errorPosition = s"${variableCode}:${line(expr).getOrElse("")}:${filename}" + logger.debug( + s"ArrayDimFetchExpr without dimensions should be handled in assignment: ${errorPosition}" + ) + Ast() + end astForArrayDimFetchExpr + + private def astForErrorSuppressExpr(expr: PhpErrorSuppressExpr): Ast = + val childAst = astForExpr(expr.expr) + + val code = s"@${childAst.rootCodeOrEmpty}" + val suppressNode = newOperatorCallNode( + PhpOperators.errorSuppress, + code, + line = line(expr), + column = column(expr) + ) + childAst.rootType.foreach(typ => suppressNode.typeFullName(typ)) - case None => - val errorPosition = s"${variableCode}:${line(expr).getOrElse("")}:${filename}" - logger.debug( - s"ArrayDimFetchExpr without dimensions should be handled in assignment: ${errorPosition}" - ) - Ast() - end astForArrayDimFetchExpr + callAst(suppressNode, childAst :: Nil) - private def astForErrorSuppressExpr(expr: PhpErrorSuppressExpr): Ast = - val childAst = astForExpr(expr.expr) + private def astForInstanceOfExpr(expr: PhpInstanceOfExpr): Ast = + val exprAst = astForExpr(expr.expr) + val classAst = astForExpr(expr.className) - val code = s"@${childAst.rootCodeOrEmpty}" - val suppressNode = newOperatorCallNode( - PhpOperators.errorSuppress, + val code = s"${exprAst.rootCodeOrEmpty} instanceof ${classAst.rootCodeOrEmpty}" + val instanceOfNode = + newOperatorCallNode( + Operators.instanceOf, code, - line = line(expr), + Some(TypeConstants.Bool), + line(expr), column = column(expr) ) - childAst.rootType.foreach(typ => suppressNode.typeFullName(typ)) - callAst(suppressNode, childAst :: Nil) + callAst(instanceOfNode, exprAst :: classAst :: Nil) - private def astForInstanceOfExpr(expr: PhpInstanceOfExpr): Ast = - val exprAst = astForExpr(expr.expr) - val classAst = astForExpr(expr.className) + private def astForPropertyFetchExpr(expr: PhpPropertyFetchExpr): Ast = + val objExprAst = astForExpr(expr.expr) - val code = s"${exprAst.rootCodeOrEmpty} instanceof ${classAst.rootCodeOrEmpty}" - val instanceOfNode = - newOperatorCallNode( - Operators.instanceOf, - code, - Some(TypeConstants.Bool), - line(expr), - column = column(expr) - ) + val fieldAst = expr.name match + case name: PhpNameExpr => + Ast(newFieldIdentifierNode(name.name, line(expr), column = column(expr))) + case other => astForExpr(other) - callAst(instanceOfNode, exprAst :: classAst :: Nil) + val accessSymbol = + if expr.isStatic then + "::" + else if expr.isNullsafe then + "?->" + else + "->" + + val code = s"${objExprAst.rootCodeOrEmpty}$accessSymbol${fieldAst.rootCodeOrEmpty}" + val fieldAccessNode = newOperatorCallNode( + Operators.fieldAccess, + code, + line = line(expr), + column = column(expr) + ) - private def astForPropertyFetchExpr(expr: PhpPropertyFetchExpr): Ast = - val objExprAst = astForExpr(expr.expr) + callAst(fieldAccessNode, objExprAst :: fieldAst :: Nil) + end astForPropertyFetchExpr - val fieldAst = expr.name match - case name: PhpNameExpr => - Ast(newFieldIdentifierNode(name.name, line(expr), column = column(expr))) - case other => astForExpr(other) + private def astForIncludeExpr(expr: PhpIncludeExpr): Ast = + val exprAst = astForExpr(expr.expr) + val code = s"${expr.includeType} ${exprAst.rootCodeOrEmpty}" + val callNode = + newOperatorCallNode(expr.includeType, code, line = line(expr), column = column(expr)) - val accessSymbol = - if expr.isStatic then - "::" - else if expr.isNullsafe then - "?->" - else - "->" + callAst(callNode, exprAst :: Nil) - val code = s"${objExprAst.rootCodeOrEmpty}$accessSymbol${fieldAst.rootCodeOrEmpty}" - val fieldAccessNode = newOperatorCallNode( - Operators.fieldAccess, - code, - line = line(expr), - column = column(expr) - ) + private def astForShellExecExpr(expr: PhpShellExecExpr): Ast = + val args = astForEncapsed(expr.parts) + val code = "`" + args.rootCodeOrEmpty + "`" - callAst(fieldAccessNode, objExprAst :: fieldAst :: Nil) - end astForPropertyFetchExpr + val callNode = newOperatorCallNode( + PhpOperators.shellExec, + code, + line = line(expr), + column = column(expr) + ) - private def astForIncludeExpr(expr: PhpIncludeExpr): Ast = - val exprAst = astForExpr(expr.expr) - val code = s"${expr.includeType} ${exprAst.rootCodeOrEmpty}" - val callNode = - newOperatorCallNode(expr.includeType, code, line = line(expr), column = column(expr)) + callAst(callNode, args :: Nil) - callAst(callNode, exprAst :: Nil) + private def astForMagicClassConstant(expr: PhpClassConstFetchExpr): Ast = + val classAst = astForExpr(expr.className) + val typeFullName = expr.className match + case nameExpr: PhpNameExpr => + scope + .lookupVariable(nameExpr.name) + .flatMap(_.properties.get(PropertyNames.TYPE_FULL_NAME).map(_.toString)) + .getOrElse(nameExpr.name) - private def astForShellExecExpr(expr: PhpShellExecExpr): Ast = - val args = astForEncapsed(expr.parts) - val code = "`" + args.rootCodeOrEmpty + "`" + case expr => + classAst.rootType.orElse(classAst.rootName).getOrElse(UnresolvedNamespace) - val callNode = newOperatorCallNode( - PhpOperators.shellExec, - code, - line = line(expr), - column = column(expr) - ) + Ast(typeRefNode(expr, classAst.rootCodeOrEmpty, typeFullName)) - callAst(callNode, args :: Nil) - - private def astForMagicClassConstant(expr: PhpClassConstFetchExpr): Ast = - val classAst = astForExpr(expr.className) - val typeFullName = expr.className match - case nameExpr: PhpNameExpr => - scope - .lookupVariable(nameExpr.name) - .flatMap(_.properties.get(PropertyNames.TYPE_FULL_NAME).map(_.toString)) - .getOrElse(nameExpr.name) - - case expr => - classAst.rootType.orElse(classAst.rootName).getOrElse(UnresolvedNamespace) - - Ast(typeRefNode(expr, classAst.rootCodeOrEmpty, typeFullName)) - - private def astForClassConstFetchExpr(expr: PhpClassConstFetchExpr): Ast = - expr.constantName match - // Foo::class should be a TypeRef and not a field access - case Some(constNameExpr) if constNameExpr.name == NameConstants.Class => - astForMagicClassConstant(expr) - - case _ => - val targetAst = astForExpr(expr.className) - val fieldIdentifierName = - expr.constantName.map(_.name).getOrElse(NameConstants.Unknown) - val fieldIdentifier = - newFieldIdentifierNode(fieldIdentifierName, line(expr), column = column(expr)) - val fieldAccessCode = s"${targetAst.rootCodeOrEmpty}::${fieldIdentifier.code}" - val fieldAccessCall = - newOperatorCallNode( - Operators.fieldAccess, - fieldAccessCode, - line = line(expr), - column = column(expr) - ) - callAst(fieldAccessCall, List(targetAst, Ast(fieldIdentifier))) - - private def astForConstFetchExpr(expr: PhpConstFetchExpr): Ast = - val constName = expr.name.name - - if NameConstants.isBoolean(constName) then - Ast(literalNode(expr, constName, TypeConstants.Bool)) - else if NameConstants.isNull(constName) then - Ast(literalNode(expr, constName, TypeConstants.NullType)) - else - val namespaceName = NamespaceTraversal.globalNamespaceName - val identifier = identifierNode(expr, namespaceName, namespaceName, "ANY") - val fieldIdentifier = - newFieldIdentifierNode(constName, line = line(expr), column = column(expr)) + private def astForClassConstFetchExpr(expr: PhpClassConstFetchExpr): Ast = + expr.constantName match + // Foo::class should be a TypeRef and not a field access + case Some(constNameExpr) if constNameExpr.name == NameConstants.Class => + astForMagicClassConstant(expr) - val fieldAccessNode = + case _ => + val targetAst = astForExpr(expr.className) + val fieldIdentifierName = + expr.constantName.map(_.name).getOrElse(NameConstants.Unknown) + val fieldIdentifier = + newFieldIdentifierNode(fieldIdentifierName, line(expr), column = column(expr)) + val fieldAccessCode = s"${targetAst.rootCodeOrEmpty}::${fieldIdentifier.code}" + val fieldAccessCall = newOperatorCallNode( Operators.fieldAccess, - code = constName, + fieldAccessCode, line = line(expr), column = column(expr) ) - val args = List(identifier, fieldIdentifier).map(Ast(_)) - - callAst(fieldAccessNode, args) - end if - end astForConstFetchExpr - - protected def line(phpNode: PhpNode): Option[Integer] = phpNode.attributes.lineNumber - protected def column(phpNode: PhpNode): Option[Integer] = phpNode.attributes.columnNumber - protected def lineEnd(phpNode: PhpNode): Option[Integer] = None - protected def columnEnd(phpNode: PhpNode): Option[Integer] = None - protected def code(phpNode: PhpNode): String = - "" // Sadly, the Php AST does not carry any code fields + callAst(fieldAccessCall, List(targetAst, Ast(fieldIdentifier))) + + private def astForConstFetchExpr(expr: PhpConstFetchExpr): Ast = + val constName = expr.name.name + + if NameConstants.isBoolean(constName) then + Ast(literalNode(expr, constName, TypeConstants.Bool)) + else if NameConstants.isNull(constName) then + Ast(literalNode(expr, constName, TypeConstants.NullType)) + else + val namespaceName = NamespaceTraversal.globalNamespaceName + val identifier = identifierNode(expr, namespaceName, namespaceName, "ANY") + val fieldIdentifier = + newFieldIdentifierNode(constName, line = line(expr), column = column(expr)) + + val fieldAccessNode = + newOperatorCallNode( + Operators.fieldAccess, + code = constName, + line = line(expr), + column = column(expr) + ) + val args = List(identifier, fieldIdentifier).map(Ast(_)) + + callAst(fieldAccessNode, args) + end if + end astForConstFetchExpr + + protected def line(phpNode: PhpNode): Option[Integer] = phpNode.attributes.lineNumber + protected def column(phpNode: PhpNode): Option[Integer] = phpNode.attributes.columnNumber + protected def lineEnd(phpNode: PhpNode): Option[Integer] = None + protected def columnEnd(phpNode: PhpNode): Option[Integer] = None + protected def code(phpNode: PhpNode): String = + "" // Sadly, the Php AST does not carry any code fields end AstCreator object AstCreator: - object TypeConstants: - val String: String = "string" - val Int: String = "int" - val Float: String = "float" - val Bool: String = "bool" - val Void: String = "void" - val Any: String = "ANY" - val Array: String = "array" - val NullType: String = "null" - val VariadicPlaceholder: String = "PhpVariadicPlaceholder" - - object NameConstants: - val Default: String = "default" - val HaltCompiler: String = "__halt_compiler" - val This: String = "this" - val Unknown: String = "UNKNOWN" - val Closure: String = "__closure" - val Class: String = "class" - val True: String = "true"; - val False: String = "false"; - val NullName: String = "null"; - - def isBoolean(name: String): Boolean = - List(True, False).contains(name) - - def isNull(name: String): Boolean = - name.toLowerCase == NullName - - val operatorSymbols: Map[String, String] = Map( - Operators.and -> "&", - Operators.or -> "|", - Operators.xor -> "^", - Operators.logicalAnd -> "&&", - Operators.logicalOr -> "||", - PhpOperators.coalesceOp -> "??", - PhpOperators.concatOp -> ".", - Operators.division -> "/", - Operators.equals -> "==", - Operators.greaterEqualsThan -> ">=", - Operators.greaterThan -> ">", - PhpOperators.identicalOp -> "===", - PhpOperators.logicalXorOp -> "xor", - Operators.minus -> "-", - Operators.modulo -> "%", - Operators.multiplication -> "*", - Operators.notEquals -> "!=", - PhpOperators.notIdenticalOp -> "!==", - Operators.plus -> "+", - Operators.exponentiation -> "**", - Operators.shiftLeft -> "<<", - Operators.arithmeticShiftRight -> ">>", - Operators.lessEqualsThan -> "<=", - Operators.lessThan -> "<", - PhpOperators.spaceshipOp -> "<=>", - Operators.not -> "~", - Operators.logicalNot -> "!", - Operators.postDecrement -> "--", - Operators.postIncrement -> "++", - Operators.preDecrement -> "--", - Operators.preIncrement -> "++", - Operators.minus -> "-", - Operators.plus -> "+", - Operators.assignment -> "=", - Operators.assignmentAnd -> "&=", - Operators.assignmentOr -> "|=", - Operators.assignmentXor -> "^=", - PhpOperators.assignmentCoalesceOp -> "??=", - PhpOperators.assignmentConcatOp -> ".=", - Operators.assignmentDivision -> "/=", - Operators.assignmentMinus -> "-=", - Operators.assignmentModulo -> "%=", - Operators.assignmentMultiplication -> "*=", - Operators.assignmentPlus -> "+=", - Operators.assignmentExponentiation -> "**=", - Operators.assignmentShiftLeft -> "<<=", - Operators.assignmentArithmeticShiftRight -> ">>=" - ) + object TypeConstants: + val String: String = "string" + val Int: String = "int" + val Float: String = "float" + val Bool: String = "bool" + val Void: String = "void" + val Any: String = "ANY" + val Array: String = "array" + val NullType: String = "null" + val VariadicPlaceholder: String = "PhpVariadicPlaceholder" + + object NameConstants: + val Default: String = "default" + val HaltCompiler: String = "__halt_compiler" + val This: String = "this" + val Unknown: String = "UNKNOWN" + val Closure: String = "__closure" + val Class: String = "class" + val True: String = "true"; + val False: String = "false"; + val NullName: String = "null"; + + def isBoolean(name: String): Boolean = + List(True, False).contains(name) + + def isNull(name: String): Boolean = + name.toLowerCase == NullName + + val operatorSymbols: Map[String, String] = Map( + Operators.and -> "&", + Operators.or -> "|", + Operators.xor -> "^", + Operators.logicalAnd -> "&&", + Operators.logicalOr -> "||", + PhpOperators.coalesceOp -> "??", + PhpOperators.concatOp -> ".", + Operators.division -> "/", + Operators.equals -> "==", + Operators.greaterEqualsThan -> ">=", + Operators.greaterThan -> ">", + PhpOperators.identicalOp -> "===", + PhpOperators.logicalXorOp -> "xor", + Operators.minus -> "-", + Operators.modulo -> "%", + Operators.multiplication -> "*", + Operators.notEquals -> "!=", + PhpOperators.notIdenticalOp -> "!==", + Operators.plus -> "+", + Operators.exponentiation -> "**", + Operators.shiftLeft -> "<<", + Operators.arithmeticShiftRight -> ">>", + Operators.lessEqualsThan -> "<=", + Operators.lessThan -> "<", + PhpOperators.spaceshipOp -> "<=>", + Operators.not -> "~", + Operators.logicalNot -> "!", + Operators.postDecrement -> "--", + Operators.postIncrement -> "++", + Operators.preDecrement -> "--", + Operators.preIncrement -> "++", + Operators.minus -> "-", + Operators.plus -> "+", + Operators.assignment -> "=", + Operators.assignmentAnd -> "&=", + Operators.assignmentOr -> "|=", + Operators.assignmentXor -> "^=", + PhpOperators.assignmentCoalesceOp -> "??=", + PhpOperators.assignmentConcatOp -> ".=", + Operators.assignmentDivision -> "/=", + Operators.assignmentMinus -> "-=", + Operators.assignmentModulo -> "%=", + Operators.assignmentMultiplication -> "*=", + Operators.assignmentPlus -> "+=", + Operators.assignmentExponentiation -> "**=", + Operators.assignmentShiftLeft -> "<<=", + Operators.assignmentArithmeticShiftRight -> ">>=" + ) end AstCreator diff --git a/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/astcreation/PhpBuiltins.scala b/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/astcreation/PhpBuiltins.scala index 52770dcc..be93c103 100644 --- a/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/astcreation/PhpBuiltins.scala +++ b/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/astcreation/PhpBuiltins.scala @@ -5,5 +5,5 @@ import io.shiftleft.utils.IOUtils import scala.io.Source object PhpBuiltins: - lazy val FuncNames: Set[String] = - Source.fromResource("builtin_functions.txt").getLines().toSet + lazy val FuncNames: Set[String] = + Source.fromResource("builtin_functions.txt").getLines().toSet diff --git a/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/parser/Domain.scala b/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/parser/Domain.scala index 2434f377..15746052 100644 --- a/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/parser/Domain.scala +++ b/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/parser/Domain.scala @@ -15,1442 +15,1442 @@ import scala.util.{Success, Try} object Domain: - object PhpOperators: - // TODO Decide which of these should be moved to codepropertygraph - val coalesceOp = ".coalesce" - val concatOp = ".concat" - val identicalOp = ".identical" - val logicalXorOp = ".logicalXor" - val notIdenticalOp = ".notIdentical" - val spaceshipOp = ".spaceship" - val elvisOp = ".elvis" - val unpack = ".unpack" - // Used for $array[] = $var type assignments - val emptyArrayIdx = ".emptyArrayIdx" - val errorSuppress = ".errorSuppress" - // Double arrow operator used to represent key/value pairs: key => value - val doubleArrow = ".doubleArrow" - - val assignmentCoalesceOp = ".assignmentCoalesce" - val assignmentConcatOp = ".assignmentConcat" - - val encaps = "encaps" - val declareFunc = "declare" - val global = "global" - - // These are handled as special cases for builtins since they have separate AST nodes in the PHP-parser output. - val issetFunc = s"isset" - val printFunc = s"print" - val cloneFunc = s"clone" - val emptyFunc = s"empty" - val evalFunc = s"eval" - val exitFunc = s"exit" - // Used for multiple assignments for example `list($a, $b) = $someArray` - val listFunc = s"list" - val isNull = s"is_null" - val unset = s"unset" - val shellExec = s"shell_exec" - end PhpOperators - - object PhpDomainTypeConstants: - val array = "array" - val bool = "bool" - val double = "double" - val int = "int" - val obj = "object" - val string = "string" - val unset = "unset" - - private val logger = LoggerFactory.getLogger(Domain.getClass) - val NamespaceDelimiter = "\\" - val StaticMethodDelimiter = "::" - val InstanceMethodDelimiter = "->" - // Used for creating the default constructor. - val ConstructorMethodName = "__construct" - - final case class PhpAttributes( - lineNumber: Option[Integer], - columnNumber: Option[Integer], - kind: Option[Int] + object PhpOperators: + // TODO Decide which of these should be moved to codepropertygraph + val coalesceOp = ".coalesce" + val concatOp = ".concat" + val identicalOp = ".identical" + val logicalXorOp = ".logicalXor" + val notIdenticalOp = ".notIdentical" + val spaceshipOp = ".spaceship" + val elvisOp = ".elvis" + val unpack = ".unpack" + // Used for $array[] = $var type assignments + val emptyArrayIdx = ".emptyArrayIdx" + val errorSuppress = ".errorSuppress" + // Double arrow operator used to represent key/value pairs: key => value + val doubleArrow = ".doubleArrow" + + val assignmentCoalesceOp = ".assignmentCoalesce" + val assignmentConcatOp = ".assignmentConcat" + + val encaps = "encaps" + val declareFunc = "declare" + val global = "global" + + // These are handled as special cases for builtins since they have separate AST nodes in the PHP-parser output. + val issetFunc = s"isset" + val printFunc = s"print" + val cloneFunc = s"clone" + val emptyFunc = s"empty" + val evalFunc = s"eval" + val exitFunc = s"exit" + // Used for multiple assignments for example `list($a, $b) = $someArray` + val listFunc = s"list" + val isNull = s"is_null" + val unset = s"unset" + val shellExec = s"shell_exec" + end PhpOperators + + object PhpDomainTypeConstants: + val array = "array" + val bool = "bool" + val double = "double" + val int = "int" + val obj = "object" + val string = "string" + val unset = "unset" + + private val logger = LoggerFactory.getLogger(Domain.getClass) + val NamespaceDelimiter = "\\" + val StaticMethodDelimiter = "::" + val InstanceMethodDelimiter = "->" + // Used for creating the default constructor. + val ConstructorMethodName = "__construct" + + final case class PhpAttributes( + lineNumber: Option[Integer], + columnNumber: Option[Integer], + kind: Option[Int] + ) + object PhpAttributes: + val Empty: PhpAttributes = PhpAttributes(None, None, None) + + def apply(json: Value): PhpAttributes = + Try(json("attributes")) match + case Success(Obj(attributes)) => + val startLine = + attributes.get("startLine").map(num => Integer.valueOf(num.num.toInt)) + val startColumn = + attributes.get("startFilePos").map(num => Integer.valueOf(num.num.toInt)) + val kind = attributes.get("kind").map(_.num.toInt) + PhpAttributes(startLine, startColumn, kind) + + case Success(Arr(_)) => + logger.debug(s"Found array attributes in $json") + PhpAttributes.Empty + + case unhandled => + logger.debug(s"Could not find attributes object in type $unhandled") + PhpAttributes.Empty + end PhpAttributes + + object PhpModifiers: + private val ModifierMasks = List( + (1, ModifierTypes.PUBLIC), + (2, ModifierTypes.PROTECTED), + (4, ModifierTypes.PRIVATE), + (8, ModifierTypes.STATIC), + (16, ModifierTypes.ABSTRACT), + (32, ModifierTypes.FINAL), + (64, ModifierTypes.READONLY) ) - object PhpAttributes: - val Empty: PhpAttributes = PhpAttributes(None, None, None) - - def apply(json: Value): PhpAttributes = - Try(json("attributes")) match - case Success(Obj(attributes)) => - val startLine = - attributes.get("startLine").map(num => Integer.valueOf(num.num.toInt)) - val startColumn = - attributes.get("startFilePos").map(num => Integer.valueOf(num.num.toInt)) - val kind = attributes.get("kind").map(_.num.toInt) - PhpAttributes(startLine, startColumn, kind) - - case Success(Arr(_)) => - logger.debug(s"Found array attributes in $json") - PhpAttributes.Empty - - case unhandled => - logger.debug(s"Could not find attributes object in type $unhandled") - PhpAttributes.Empty - end PhpAttributes - - object PhpModifiers: - private val ModifierMasks = List( - (1, ModifierTypes.PUBLIC), - (2, ModifierTypes.PROTECTED), - (4, ModifierTypes.PRIVATE), - (8, ModifierTypes.STATIC), - (16, ModifierTypes.ABSTRACT), - (32, ModifierTypes.FINAL), - (64, ModifierTypes.READONLY) - ) - - private val AccessModifiers: Set[String] = - Set(ModifierTypes.PUBLIC, ModifierTypes.PROTECTED, ModifierTypes.PRIVATE) - - def containsAccessModifier(modifiers: List[String]): Boolean = - modifiers.toSet.intersect(AccessModifiers).nonEmpty - - def getModifierSet(json: Value, modifierString: String = "flags"): List[String] = - val flags = json.objOpt.flatMap(_.get(modifierString)).map(_.num.toInt).getOrElse(0) - ModifierMasks.collect { - case (mask, typ) if (flags & mask) != 0 => typ - } - end PhpModifiers - - sealed trait PhpNode: - def attributes: PhpAttributes - - final case class PhpFile(children: List[PhpStmt]) extends PhpNode: - override val attributes: PhpAttributes = PhpAttributes.Empty - - final case class PhpParam( - name: String, - paramType: Option[PhpNameExpr], - byRef: Boolean, - isVariadic: Boolean, - default: Option[PhpExpr], - // TODO type - flags: Int, - // TODO attributeGroups: Seq[PhpAttributeGroup], - attributes: PhpAttributes - ) extends PhpNode - - sealed trait PhpArgument extends PhpNode - final case class PhpArg( - expr: PhpExpr, - parameterName: Option[String], - byRef: Boolean, - unpack: Boolean, - attributes: PhpAttributes - ) extends PhpArgument - object PhpArg: - def apply(expr: PhpExpr): PhpArg = - PhpArg( - expr, - parameterName = None, - byRef = false, - unpack = false, - attributes = expr.attributes - ) - final case class PhpVariadicPlaceholder(attributes: Domain.PhpAttributes) extends PhpArgument - - sealed trait PhpStmt extends PhpNode - sealed trait PhpStmtWithBody extends PhpStmt: - def stmts: List[PhpStmt] - - // In the PhpParser output, comments are included as an attribute to the first statement following the comment. If - // no such statement exists, a Nop statement (which does not exist in PHP) is added as a sort of comment container. - final case class NopStmt(attributes: PhpAttributes) extends PhpStmt - final case class PhpEchoStmt(exprs: Seq[PhpExpr], attributes: PhpAttributes) extends PhpStmt - final case class PhpBreakStmt(num: Option[Int], attributes: PhpAttributes) extends PhpStmt - final case class PhpContinueStmt(num: Option[Int], attributes: PhpAttributes) extends PhpStmt - final case class PhpWhileStmt(cond: PhpExpr, stmts: List[PhpStmt], attributes: PhpAttributes) - extends PhpStmtWithBody - final case class PhpDoStmt(cond: PhpExpr, stmts: List[PhpStmt], attributes: PhpAttributes) - extends PhpStmtWithBody - final case class PhpForStmt( - inits: List[PhpExpr], - conditions: List[PhpExpr], - loopExprs: List[PhpExpr], - stmts: List[PhpStmt], - attributes: PhpAttributes - ) extends PhpStmtWithBody - final case class PhpIfStmt( - cond: PhpExpr, - stmts: List[PhpStmt], - elseIfs: List[PhpElseIfStmt], - elseStmt: Option[PhpElseStmt], - attributes: PhpAttributes - ) extends PhpStmtWithBody - final case class PhpElseIfStmt(cond: PhpExpr, stmts: List[PhpStmt], attributes: PhpAttributes) - extends PhpStmtWithBody - final case class PhpElseStmt(stmts: List[PhpStmt], attributes: PhpAttributes) - extends PhpStmtWithBody - final case class PhpSwitchStmt( - condition: PhpExpr, - cases: List[PhpCaseStmt], - attributes: PhpAttributes - ) extends PhpStmt - final case class PhpCaseStmt( - condition: Option[PhpExpr], - stmts: List[PhpStmt], - attributes: PhpAttributes - ) extends PhpStmtWithBody - final case class PhpTryStmt( - stmts: List[PhpStmt], - catches: List[PhpCatchStmt], - finallyStmt: Option[PhpFinallyStmt], - attributes: PhpAttributes - ) extends PhpStmtWithBody - final case class PhpCatchStmt( - types: List[PhpNameExpr], - variable: Option[PhpExpr], - stmts: List[PhpStmt], - attributes: PhpAttributes - ) extends PhpStmtWithBody - final case class PhpFinallyStmt(stmts: List[PhpStmt], attributes: PhpAttributes) - extends PhpStmtWithBody - final case class PhpReturnStmt(expr: Option[PhpExpr], attributes: PhpAttributes) extends PhpStmt - - final case class PhpMethodDecl( - name: PhpNameExpr, - params: Seq[PhpParam], - modifiers: List[String], - returnType: Option[PhpNameExpr], - stmts: List[PhpStmt], - returnByRef: Boolean, - // TODO attributeGroups: Seq[PhpAttributeGroup], - namespacedName: Option[PhpNameExpr], - isClassMethod: Boolean, - attributes: PhpAttributes - ) extends PhpStmtWithBody - - final case class PhpClassLikeStmt( - name: Option[PhpNameExpr], - modifiers: List[String], - extendsNames: List[PhpNameExpr], - implementedInterfaces: List[PhpNameExpr], - stmts: List[PhpStmt], - classLikeType: String, - // Optionally used for enums with values - scalarType: Option[PhpNameExpr], - hasConstructor: Boolean, - attributes: PhpAttributes - ) extends PhpStmtWithBody - object ClassLikeTypes: - val Class: String = "class" - val Trait: String = "trait" - val Interface: String = "interface" - val Enum: String = "enum" - - final case class PhpEnumCaseStmt( - name: PhpNameExpr, - expr: Option[PhpExpr], - attributes: PhpAttributes - ) extends PhpStmt - - final case class PhpPropertyStmt( - modifiers: List[String], - variables: List[PhpPropertyValue], - typeName: Option[PhpNameExpr], - attributes: PhpAttributes - ) extends PhpStmt - - final case class PhpPropertyValue( - name: PhpNameExpr, - defaultValue: Option[PhpExpr], - attributes: PhpAttributes - ) extends PhpStmt - - final case class PhpConstStmt( - modifiers: List[String], - consts: List[PhpConstDeclaration], - attributes: PhpAttributes - ) extends PhpStmt - - final case class PhpGotoStmt(label: PhpNameExpr, attributes: PhpAttributes) extends PhpStmt - final case class PhpLabelStmt(label: PhpNameExpr, attributes: PhpAttributes) extends PhpStmt - final case class PhpHaltCompilerStmt(attributes: PhpAttributes) extends PhpStmt - - final case class PhpConstDeclaration( - name: PhpNameExpr, - value: PhpExpr, - namespacedName: Option[PhpNameExpr], - attributes: PhpAttributes - ) extends PhpStmt - - final case class PhpNamespaceStmt( - name: Option[PhpNameExpr], - stmts: List[PhpStmt], - attributes: PhpAttributes - ) extends PhpStmtWithBody - - final case class PhpDeclareStmt( - declares: Seq[PhpDeclareItem], - stmts: Option[List[PhpStmt]], - attributes: PhpAttributes - ) extends PhpStmt - final case class PhpDeclareItem(key: PhpNameExpr, value: PhpExpr, attributes: PhpAttributes) - extends PhpStmt - - final case class PhpUnsetStmt(vars: List[PhpExpr], attributes: PhpAttributes) extends PhpStmt - - final case class PhpStaticStmt(vars: List[PhpStaticVar], attributes: PhpAttributes) - extends PhpStmt - - final case class PhpStaticVar( - variable: PhpVariable, - defaultValue: Option[PhpExpr], - attributes: PhpAttributes - ) extends PhpStmt - - final case class PhpGlobalStmt(vars: List[PhpExpr], attributes: PhpAttributes) extends PhpStmt - - final case class PhpUseStmt( - uses: List[PhpUseUse], - useType: PhpUseType, - attributes: PhpAttributes - ) extends PhpStmt - final case class PhpGroupUseStmt( - prefix: PhpNameExpr, - uses: List[PhpUseUse], - useType: PhpUseType, - attributes: PhpAttributes - ) extends PhpStmt - final case class PhpUseUse( - originalName: PhpNameExpr, - alias: Option[PhpNameExpr], - useType: PhpUseType, - attributes: PhpAttributes - ) extends PhpStmt - - case object PhpUseType: - sealed trait PhpUseType - case object Unknown extends PhpUseType - case object Normal extends PhpUseType - case object Function extends PhpUseType - case object Constant extends PhpUseType - - def getUseType(typeNum: Int): PhpUseType = - typeNum match - case 1 => Normal - case 2 => Function - case 3 => Constant - case _ => Unknown - - final case class PhpForeachStmt( - iterExpr: PhpExpr, - keyVar: Option[PhpExpr], - valueVar: PhpExpr, - assignByRef: Boolean, - stmts: List[PhpStmt], - attributes: PhpAttributes - ) extends PhpStmtWithBody - final case class PhpTraitUseStmt( - traits: List[PhpNameExpr], - adaptations: List[PhpTraitUseAdaptation], - attributes: PhpAttributes - ) extends PhpStmt - sealed trait PhpTraitUseAdaptation extends PhpStmt - final case class PhpPrecedenceAdaptation( - traitName: PhpNameExpr, - methodName: PhpNameExpr, - insteadOf: List[PhpNameExpr], - attributes: PhpAttributes - ) extends PhpTraitUseAdaptation - final case class PhpAliasAdaptation( - traitName: Option[PhpNameExpr], - methodName: PhpNameExpr, - newModifier: Option[String], - newName: Option[PhpNameExpr], - attributes: PhpAttributes - ) extends PhpTraitUseAdaptation - - sealed trait PhpExpr extends PhpStmt - - final case class PhpNewExpr( - className: PhpNode, - args: List[PhpArgument], - attributes: PhpAttributes - ) extends PhpExpr - - final case class PhpIncludeExpr(expr: PhpExpr, includeType: String, attributes: PhpAttributes) - extends PhpExpr - case object PhpIncludeType: - val Include: String = "include" - val IncludeOnce: String = "include_once" - val Require: String = "require" - val RequireOnce: String = "require_once" - - final case class PhpCallExpr( - target: Option[PhpExpr], - methodName: PhpExpr, - args: Seq[PhpArgument], - isNullSafe: Boolean, - isStatic: Boolean, - attributes: PhpAttributes - ) extends PhpExpr - final case class PhpVariable(value: PhpExpr, attributes: PhpAttributes) extends PhpExpr - final case class PhpNameExpr(name: String, attributes: PhpAttributes) extends PhpExpr - final case class PhpCloneExpr(expr: PhpExpr, attributes: PhpAttributes) extends PhpExpr - final case class PhpEmptyExpr(expr: PhpExpr, attributes: PhpAttributes) extends PhpExpr - final case class PhpEvalExpr(expr: PhpExpr, attributes: PhpAttributes) extends PhpExpr - final case class PhpExitExpr(expr: Option[PhpExpr], attributes: PhpAttributes) extends PhpExpr - final case class PhpBinaryOp( - operator: String, - left: PhpExpr, - right: PhpExpr, - attributes: PhpAttributes - ) extends PhpExpr - object PhpBinaryOp: - val BinaryOpTypeMap: Map[String, String] = Map( - "Expr_BinaryOp_BitwiseAnd" -> Operators.and, - "Expr_BinaryOp_BitwiseOr" -> Operators.or, - "Expr_BinaryOp_BitwiseXor" -> Operators.xor, - "Expr_BinaryOp_BooleanAnd" -> Operators.logicalAnd, - "Expr_BinaryOp_BooleanOr" -> Operators.logicalOr, - "Expr_BinaryOp_Coalesce" -> PhpOperators.coalesceOp, - "Expr_BinaryOp_Concat" -> PhpOperators.concatOp, - "Expr_BinaryOp_Div" -> Operators.division, - "Expr_BinaryOp_Equal" -> Operators.equals, - "Expr_BinaryOp_GreaterOrEqual" -> Operators.greaterEqualsThan, - "Expr_BinaryOp_Greater" -> Operators.greaterThan, - "Expr_BinaryOp_Identical" -> PhpOperators.identicalOp, - "Expr_BinaryOp_LogicalAnd" -> Operators.logicalAnd, - "Expr_BinaryOp_LogicalOr" -> Operators.logicalOr, - "Expr_BinaryOp_LogicalXor" -> PhpOperators.logicalXorOp, - "Expr_BinaryOp_Minus" -> Operators.minus, - "Expr_BinaryOp_Mod" -> Operators.modulo, - "Expr_BinaryOp_Mul" -> Operators.multiplication, - "Expr_BinaryOp_NotEqual" -> Operators.notEquals, - "Expr_BinaryOp_NotIdentical" -> PhpOperators.notIdenticalOp, - "Expr_BinaryOp_Plus" -> Operators.plus, - "Expr_BinaryOp_Pow" -> Operators.exponentiation, - "Expr_BinaryOp_ShiftLeft" -> Operators.shiftLeft, - "Expr_BinaryOp_ShiftRight" -> Operators.arithmeticShiftRight, - "Expr_BinaryOp_SmallerOrEqual" -> Operators.lessEqualsThan, - "Expr_BinaryOp_Smaller" -> Operators.lessThan, - "Expr_BinaryOp_Spaceship" -> PhpOperators.spaceshipOp - ) - def isBinaryOpType(typeName: String): Boolean = - BinaryOpTypeMap.contains(typeName) - end PhpBinaryOp - final case class PhpUnaryOp(operator: String, expr: PhpExpr, attributes: PhpAttributes) - extends PhpExpr - object PhpUnaryOp: - val UnaryOpTypeMap: Map[String, String] = Map( - "Expr_BitwiseNot" -> Operators.not, - "Expr_BooleanNot" -> Operators.logicalNot, - "Expr_PostDec" -> Operators.postDecrement, - "Expr_PostInc" -> Operators.postIncrement, - "Expr_PreDec" -> Operators.preDecrement, - "Expr_PreInc" -> Operators.preIncrement, - "Expr_UnaryMinus" -> Operators.minus, - "Expr_UnaryPlus" -> Operators.plus + private val AccessModifiers: Set[String] = + Set(ModifierTypes.PUBLIC, ModifierTypes.PROTECTED, ModifierTypes.PRIVATE) + + def containsAccessModifier(modifiers: List[String]): Boolean = + modifiers.toSet.intersect(AccessModifiers).nonEmpty + + def getModifierSet(json: Value, modifierString: String = "flags"): List[String] = + val flags = json.objOpt.flatMap(_.get(modifierString)).map(_.num.toInt).getOrElse(0) + ModifierMasks.collect { + case (mask, typ) if (flags & mask) != 0 => typ + } + end PhpModifiers + + sealed trait PhpNode: + def attributes: PhpAttributes + + final case class PhpFile(children: List[PhpStmt]) extends PhpNode: + override val attributes: PhpAttributes = PhpAttributes.Empty + + final case class PhpParam( + name: String, + paramType: Option[PhpNameExpr], + byRef: Boolean, + isVariadic: Boolean, + default: Option[PhpExpr], + // TODO type + flags: Int, + // TODO attributeGroups: Seq[PhpAttributeGroup], + attributes: PhpAttributes + ) extends PhpNode + + sealed trait PhpArgument extends PhpNode + final case class PhpArg( + expr: PhpExpr, + parameterName: Option[String], + byRef: Boolean, + unpack: Boolean, + attributes: PhpAttributes + ) extends PhpArgument + object PhpArg: + def apply(expr: PhpExpr): PhpArg = + PhpArg( + expr, + parameterName = None, + byRef = false, + unpack = false, + attributes = expr.attributes ) + final case class PhpVariadicPlaceholder(attributes: Domain.PhpAttributes) extends PhpArgument + + sealed trait PhpStmt extends PhpNode + sealed trait PhpStmtWithBody extends PhpStmt: + def stmts: List[PhpStmt] + + // In the PhpParser output, comments are included as an attribute to the first statement following the comment. If + // no such statement exists, a Nop statement (which does not exist in PHP) is added as a sort of comment container. + final case class NopStmt(attributes: PhpAttributes) extends PhpStmt + final case class PhpEchoStmt(exprs: Seq[PhpExpr], attributes: PhpAttributes) extends PhpStmt + final case class PhpBreakStmt(num: Option[Int], attributes: PhpAttributes) extends PhpStmt + final case class PhpContinueStmt(num: Option[Int], attributes: PhpAttributes) extends PhpStmt + final case class PhpWhileStmt(cond: PhpExpr, stmts: List[PhpStmt], attributes: PhpAttributes) + extends PhpStmtWithBody + final case class PhpDoStmt(cond: PhpExpr, stmts: List[PhpStmt], attributes: PhpAttributes) + extends PhpStmtWithBody + final case class PhpForStmt( + inits: List[PhpExpr], + conditions: List[PhpExpr], + loopExprs: List[PhpExpr], + stmts: List[PhpStmt], + attributes: PhpAttributes + ) extends PhpStmtWithBody + final case class PhpIfStmt( + cond: PhpExpr, + stmts: List[PhpStmt], + elseIfs: List[PhpElseIfStmt], + elseStmt: Option[PhpElseStmt], + attributes: PhpAttributes + ) extends PhpStmtWithBody + final case class PhpElseIfStmt(cond: PhpExpr, stmts: List[PhpStmt], attributes: PhpAttributes) + extends PhpStmtWithBody + final case class PhpElseStmt(stmts: List[PhpStmt], attributes: PhpAttributes) + extends PhpStmtWithBody + final case class PhpSwitchStmt( + condition: PhpExpr, + cases: List[PhpCaseStmt], + attributes: PhpAttributes + ) extends PhpStmt + final case class PhpCaseStmt( + condition: Option[PhpExpr], + stmts: List[PhpStmt], + attributes: PhpAttributes + ) extends PhpStmtWithBody + final case class PhpTryStmt( + stmts: List[PhpStmt], + catches: List[PhpCatchStmt], + finallyStmt: Option[PhpFinallyStmt], + attributes: PhpAttributes + ) extends PhpStmtWithBody + final case class PhpCatchStmt( + types: List[PhpNameExpr], + variable: Option[PhpExpr], + stmts: List[PhpStmt], + attributes: PhpAttributes + ) extends PhpStmtWithBody + final case class PhpFinallyStmt(stmts: List[PhpStmt], attributes: PhpAttributes) + extends PhpStmtWithBody + final case class PhpReturnStmt(expr: Option[PhpExpr], attributes: PhpAttributes) extends PhpStmt + + final case class PhpMethodDecl( + name: PhpNameExpr, + params: Seq[PhpParam], + modifiers: List[String], + returnType: Option[PhpNameExpr], + stmts: List[PhpStmt], + returnByRef: Boolean, + // TODO attributeGroups: Seq[PhpAttributeGroup], + namespacedName: Option[PhpNameExpr], + isClassMethod: Boolean, + attributes: PhpAttributes + ) extends PhpStmtWithBody + + final case class PhpClassLikeStmt( + name: Option[PhpNameExpr], + modifiers: List[String], + extendsNames: List[PhpNameExpr], + implementedInterfaces: List[PhpNameExpr], + stmts: List[PhpStmt], + classLikeType: String, + // Optionally used for enums with values + scalarType: Option[PhpNameExpr], + hasConstructor: Boolean, + attributes: PhpAttributes + ) extends PhpStmtWithBody + object ClassLikeTypes: + val Class: String = "class" + val Trait: String = "trait" + val Interface: String = "interface" + val Enum: String = "enum" + + final case class PhpEnumCaseStmt( + name: PhpNameExpr, + expr: Option[PhpExpr], + attributes: PhpAttributes + ) extends PhpStmt + + final case class PhpPropertyStmt( + modifiers: List[String], + variables: List[PhpPropertyValue], + typeName: Option[PhpNameExpr], + attributes: PhpAttributes + ) extends PhpStmt + + final case class PhpPropertyValue( + name: PhpNameExpr, + defaultValue: Option[PhpExpr], + attributes: PhpAttributes + ) extends PhpStmt + + final case class PhpConstStmt( + modifiers: List[String], + consts: List[PhpConstDeclaration], + attributes: PhpAttributes + ) extends PhpStmt + + final case class PhpGotoStmt(label: PhpNameExpr, attributes: PhpAttributes) extends PhpStmt + final case class PhpLabelStmt(label: PhpNameExpr, attributes: PhpAttributes) extends PhpStmt + final case class PhpHaltCompilerStmt(attributes: PhpAttributes) extends PhpStmt + + final case class PhpConstDeclaration( + name: PhpNameExpr, + value: PhpExpr, + namespacedName: Option[PhpNameExpr], + attributes: PhpAttributes + ) extends PhpStmt + + final case class PhpNamespaceStmt( + name: Option[PhpNameExpr], + stmts: List[PhpStmt], + attributes: PhpAttributes + ) extends PhpStmtWithBody + + final case class PhpDeclareStmt( + declares: Seq[PhpDeclareItem], + stmts: Option[List[PhpStmt]], + attributes: PhpAttributes + ) extends PhpStmt + final case class PhpDeclareItem(key: PhpNameExpr, value: PhpExpr, attributes: PhpAttributes) + extends PhpStmt + + final case class PhpUnsetStmt(vars: List[PhpExpr], attributes: PhpAttributes) extends PhpStmt + + final case class PhpStaticStmt(vars: List[PhpStaticVar], attributes: PhpAttributes) + extends PhpStmt + + final case class PhpStaticVar( + variable: PhpVariable, + defaultValue: Option[PhpExpr], + attributes: PhpAttributes + ) extends PhpStmt + + final case class PhpGlobalStmt(vars: List[PhpExpr], attributes: PhpAttributes) extends PhpStmt + + final case class PhpUseStmt( + uses: List[PhpUseUse], + useType: PhpUseType, + attributes: PhpAttributes + ) extends PhpStmt + final case class PhpGroupUseStmt( + prefix: PhpNameExpr, + uses: List[PhpUseUse], + useType: PhpUseType, + attributes: PhpAttributes + ) extends PhpStmt + final case class PhpUseUse( + originalName: PhpNameExpr, + alias: Option[PhpNameExpr], + useType: PhpUseType, + attributes: PhpAttributes + ) extends PhpStmt + + case object PhpUseType: + sealed trait PhpUseType + case object Unknown extends PhpUseType + case object Normal extends PhpUseType + case object Function extends PhpUseType + case object Constant extends PhpUseType + + def getUseType(typeNum: Int): PhpUseType = + typeNum match + case 1 => Normal + case 2 => Function + case 3 => Constant + case _ => Unknown + + final case class PhpForeachStmt( + iterExpr: PhpExpr, + keyVar: Option[PhpExpr], + valueVar: PhpExpr, + assignByRef: Boolean, + stmts: List[PhpStmt], + attributes: PhpAttributes + ) extends PhpStmtWithBody + final case class PhpTraitUseStmt( + traits: List[PhpNameExpr], + adaptations: List[PhpTraitUseAdaptation], + attributes: PhpAttributes + ) extends PhpStmt + sealed trait PhpTraitUseAdaptation extends PhpStmt + final case class PhpPrecedenceAdaptation( + traitName: PhpNameExpr, + methodName: PhpNameExpr, + insteadOf: List[PhpNameExpr], + attributes: PhpAttributes + ) extends PhpTraitUseAdaptation + final case class PhpAliasAdaptation( + traitName: Option[PhpNameExpr], + methodName: PhpNameExpr, + newModifier: Option[String], + newName: Option[PhpNameExpr], + attributes: PhpAttributes + ) extends PhpTraitUseAdaptation + + sealed trait PhpExpr extends PhpStmt + + final case class PhpNewExpr( + className: PhpNode, + args: List[PhpArgument], + attributes: PhpAttributes + ) extends PhpExpr + + final case class PhpIncludeExpr(expr: PhpExpr, includeType: String, attributes: PhpAttributes) + extends PhpExpr + case object PhpIncludeType: + val Include: String = "include" + val IncludeOnce: String = "include_once" + val Require: String = "require" + val RequireOnce: String = "require_once" + + final case class PhpCallExpr( + target: Option[PhpExpr], + methodName: PhpExpr, + args: Seq[PhpArgument], + isNullSafe: Boolean, + isStatic: Boolean, + attributes: PhpAttributes + ) extends PhpExpr + final case class PhpVariable(value: PhpExpr, attributes: PhpAttributes) extends PhpExpr + final case class PhpNameExpr(name: String, attributes: PhpAttributes) extends PhpExpr + final case class PhpCloneExpr(expr: PhpExpr, attributes: PhpAttributes) extends PhpExpr + final case class PhpEmptyExpr(expr: PhpExpr, attributes: PhpAttributes) extends PhpExpr + final case class PhpEvalExpr(expr: PhpExpr, attributes: PhpAttributes) extends PhpExpr + final case class PhpExitExpr(expr: Option[PhpExpr], attributes: PhpAttributes) extends PhpExpr + final case class PhpBinaryOp( + operator: String, + left: PhpExpr, + right: PhpExpr, + attributes: PhpAttributes + ) extends PhpExpr + object PhpBinaryOp: + val BinaryOpTypeMap: Map[String, String] = Map( + "Expr_BinaryOp_BitwiseAnd" -> Operators.and, + "Expr_BinaryOp_BitwiseOr" -> Operators.or, + "Expr_BinaryOp_BitwiseXor" -> Operators.xor, + "Expr_BinaryOp_BooleanAnd" -> Operators.logicalAnd, + "Expr_BinaryOp_BooleanOr" -> Operators.logicalOr, + "Expr_BinaryOp_Coalesce" -> PhpOperators.coalesceOp, + "Expr_BinaryOp_Concat" -> PhpOperators.concatOp, + "Expr_BinaryOp_Div" -> Operators.division, + "Expr_BinaryOp_Equal" -> Operators.equals, + "Expr_BinaryOp_GreaterOrEqual" -> Operators.greaterEqualsThan, + "Expr_BinaryOp_Greater" -> Operators.greaterThan, + "Expr_BinaryOp_Identical" -> PhpOperators.identicalOp, + "Expr_BinaryOp_LogicalAnd" -> Operators.logicalAnd, + "Expr_BinaryOp_LogicalOr" -> Operators.logicalOr, + "Expr_BinaryOp_LogicalXor" -> PhpOperators.logicalXorOp, + "Expr_BinaryOp_Minus" -> Operators.minus, + "Expr_BinaryOp_Mod" -> Operators.modulo, + "Expr_BinaryOp_Mul" -> Operators.multiplication, + "Expr_BinaryOp_NotEqual" -> Operators.notEquals, + "Expr_BinaryOp_NotIdentical" -> PhpOperators.notIdenticalOp, + "Expr_BinaryOp_Plus" -> Operators.plus, + "Expr_BinaryOp_Pow" -> Operators.exponentiation, + "Expr_BinaryOp_ShiftLeft" -> Operators.shiftLeft, + "Expr_BinaryOp_ShiftRight" -> Operators.arithmeticShiftRight, + "Expr_BinaryOp_SmallerOrEqual" -> Operators.lessEqualsThan, + "Expr_BinaryOp_Smaller" -> Operators.lessThan, + "Expr_BinaryOp_Spaceship" -> PhpOperators.spaceshipOp + ) - def isUnaryOpType(typeName: String): Boolean = - UnaryOpTypeMap.contains(typeName) - final case class PhpTernaryOp( - condition: PhpExpr, - thenExpr: Option[PhpExpr], - elseExpr: PhpExpr, - attributes: PhpAttributes - ) extends PhpExpr - - object PhpAssignment: - val AssignTypeMap: Map[String, String] = Map( - "Expr_Assign" -> Operators.assignment, - "Expr_AssignRef" -> Operators.assignment, - "Expr_AssignOp_BitwiseAnd" -> Operators.assignmentAnd, - "Expr_AssignOp_BitwiseOr" -> Operators.assignmentOr, - "Expr_AssignOp_BitwiseXor" -> Operators.assignmentXor, - "Expr_AssignOp_Coalesce" -> PhpOperators.assignmentCoalesceOp, - "Expr_AssignOp_Concat" -> PhpOperators.assignmentConcatOp, - "Expr_AssignOp_Div" -> Operators.assignmentDivision, - "Expr_AssignOp_Minus" -> Operators.assignmentMinus, - "Expr_AssignOp_Mod" -> Operators.assignmentModulo, - "Expr_AssignOp_Mul" -> Operators.assignmentMultiplication, - "Expr_AssignOp_Plus" -> Operators.assignmentPlus, - "Expr_AssignOp_Pow" -> Operators.assignmentExponentiation, - "Expr_AssignOp_ShiftLeft" -> Operators.assignmentShiftLeft, - "Expr_AssignOp_ShiftRight" -> Operators.assignmentArithmeticShiftRight - ) - - def isAssignType(typeName: String): Boolean = - AssignTypeMap.contains(typeName) - end PhpAssignment - final case class PhpAssignment( - assignOp: String, - target: PhpExpr, - source: PhpExpr, - isRefAssign: Boolean, - attributes: PhpAttributes - ) extends PhpExpr - - final case class PhpCast(typ: String, expr: PhpExpr, attributes: PhpAttributes) extends PhpExpr - object PhpCast: - val CastTypeMap: Map[String, String] = Map( - "Expr_Cast_Array" -> PhpDomainTypeConstants.array, - "Expr_Cast_Bool" -> PhpDomainTypeConstants.bool, - "Expr_Cast_Double" -> PhpDomainTypeConstants.double, - "Expr_Cast_Int" -> PhpDomainTypeConstants.int, - "Expr_Cast_Object" -> PhpDomainTypeConstants.obj, - "Expr_Cast_String" -> PhpDomainTypeConstants.string, - "Expr_Cast_Unset" -> PhpDomainTypeConstants.unset - ) - - def isCastType(typeName: String): Boolean = - CastTypeMap.contains(typeName) - - final case class PhpIsset(vars: Seq[PhpExpr], attributes: PhpAttributes) extends PhpExpr - final case class PhpPrint(expr: PhpExpr, attributes: PhpAttributes) extends PhpExpr - - sealed trait PhpScalar extends PhpExpr - sealed abstract class PhpSimpleScalar(val typeFullName: String) extends PhpScalar: - def value: String - def attributes: PhpAttributes - - final case class PhpString(val value: String, val attributes: PhpAttributes) - extends PhpSimpleScalar(TypeConstants.String) - object PhpString: - def withQuotes(value: String, attributes: PhpAttributes): PhpString = - PhpString(s"\"${escapeString(value)}\"", attributes) - - final case class PhpInt(val value: String, val attributes: PhpAttributes) - extends PhpSimpleScalar(TypeConstants.Int) - - final case class PhpFloat(val value: String, val attributes: PhpAttributes) - extends PhpSimpleScalar(TypeConstants.Float) - - final case class PhpEncapsed(parts: Seq[PhpExpr], attributes: PhpAttributes) extends PhpScalar - - final case class PhpThrowExpr(expr: PhpExpr, attributes: PhpAttributes) extends PhpExpr - final case class PhpListExpr(items: List[Option[PhpArrayItem]], attributes: PhpAttributes) - extends PhpExpr - - final case class PhpClassConstFetchExpr( - className: PhpExpr, - constantName: Option[PhpNameExpr], - attributes: PhpAttributes - ) extends PhpExpr - - final case class PhpConstFetchExpr(name: PhpNameExpr, attributes: PhpAttributes) extends PhpExpr - - final case class PhpArrayExpr(items: List[Option[PhpArrayItem]], attributes: PhpAttributes) - extends PhpExpr - final case class PhpArrayItem( - key: Option[PhpExpr], - value: PhpExpr, - byRef: Boolean, - unpack: Boolean, - attributes: PhpAttributes - ) extends PhpExpr - final case class PhpArrayDimFetchExpr( - variable: PhpExpr, - dimension: Option[PhpExpr], - attributes: PhpAttributes - ) extends PhpExpr - - final case class PhpErrorSuppressExpr(expr: PhpExpr, attributes: PhpAttributes) extends PhpExpr - - final case class PhpInstanceOfExpr(expr: PhpExpr, className: PhpExpr, attributes: PhpAttributes) - extends PhpExpr - - final case class PhpShellExecExpr(parts: PhpEncapsed, attributes: PhpAttributes) extends PhpExpr - - final case class PhpPropertyFetchExpr( - expr: PhpExpr, - name: PhpExpr, - isNullsafe: Boolean, - isStatic: Boolean, - attributes: PhpAttributes - ) extends PhpExpr - - final case class PhpMatchExpr( - condition: PhpExpr, - matchArms: List[PhpMatchArm], - attributes: PhpAttributes - ) extends PhpExpr - - final case class PhpMatchArm( - conditions: List[PhpExpr], - body: PhpExpr, - isDefault: Boolean, - attributes: PhpAttributes - ) extends PhpExpr - - final case class PhpYieldExpr( - key: Option[PhpExpr], - value: Option[PhpExpr], - attributes: PhpAttributes - ) extends PhpExpr - final case class PhpYieldFromExpr(expr: PhpExpr, attributes: PhpAttributes) extends PhpExpr - - final case class PhpClosureExpr( - params: List[PhpParam], - stmts: List[PhpStmt], - returnType: Option[PhpNameExpr], - uses: List[PhpClosureUse], - isStatic: Boolean, - returnByRef: Boolean, - isArrowFunc: Boolean, - attributes: PhpAttributes - ) extends PhpExpr - final case class PhpClosureUse(variable: PhpExpr, byRef: Boolean, attributes: PhpAttributes) - extends PhpExpr - - private def escapeString(value: String): String = - value - .replace("\\", "\\\\") - .replace("\n", "\\n") - .replace("\b", "\\b") - .replace("\r", "\\r") - .replace("\t", "\\t") - .replace("\'", "\\'") - .replace("\f", "\\f") - .replace("\"", "\\\"") - - private def readFile(json: Value): PhpFile = - json match - case arr: Arr => - val children = arr.value.map(readStmt).toList - PhpFile(children) - case unhandled => - logger.debug( - s"Found unhandled type in readFile: ${unhandled.getClass} with value $unhandled" - ) - ??? - - private def readStmt(json: Value): PhpStmt = - json("nodeType").str match - case "Stmt_Echo" => - val values = json("exprs").arr.map(readExpr).toSeq - PhpEchoStmt(values, PhpAttributes(json)) - case "Stmt_Expression" => readExpr(json("expr")) - case "Stmt_Function" => readFunction(json) - case "Stmt_InlineHTML" => readInlineHtml(json) - case "Stmt_Break" => readBreak(json) - case "Stmt_Continue" => readContinue(json) - case "Stmt_While" => readWhile(json) - case "Stmt_Do" => readDo(json) - case "Stmt_For" => readFor(json) - case "Stmt_If" => readIf(json) - case "Stmt_Switch" => readSwitch(json) - case "Stmt_TryCatch" => readTry(json) - case "Stmt_Throw" => readThrow(json) - case "Stmt_Return" => readReturn(json) - case "Stmt_Class" => readClassLike(json, ClassLikeTypes.Class) - case "Stmt_Interface" => readClassLike(json, ClassLikeTypes.Interface) - case "Stmt_Trait" => readClassLike(json, ClassLikeTypes.Trait) - case "Stmt_Enum" => readClassLike(json, ClassLikeTypes.Enum) - case "Stmt_EnumCase" => readEnumCase(json) - case "Stmt_ClassMethod" => readClassMethod(json) - case "Stmt_Property" => readProperty(json) - case "Stmt_ClassConst" => readConst(json) - case "Stmt_Const" => readConst(json) - case "Stmt_Goto" => readGoto(json) - case "Stmt_Label" => readLabel(json) - case "Stmt_HaltCompiler" => readHaltCompiler(json) - case "Stmt_Namespace" => readNamespace(json) - case "Stmt_Nop" => NopStmt(PhpAttributes(json)) - case "Stmt_Declare" => readDeclare(json) - case "Stmt_Unset" => readUnset(json) - case "Stmt_Static" => readStatic(json) - case "Stmt_Global" => readGlobal(json) - case "Stmt_Use" => readUse(json) - case "Stmt_GroupUse" => readGroupUse(json) - case "Stmt_Foreach" => readForeach(json) - case "Stmt_TraitUse" => readTraitUse(json) - case "Stmt_Block" => NopStmt(PhpAttributes(json)) - case unhandled => NopStmt(PhpAttributes(json)) - - private def readString(json: Value): PhpString = - PhpString.withQuotes(json("value").str, PhpAttributes(json)) - - private def readInlineHtml(json: Value): PhpStmt = - val value = readString(json) - PhpEchoStmt(List(value), value.attributes) - - private def readBreakContinueNum(json: Value): Option[Int] = - Option.unless(json("num").isNull)(json("num")("value").toString).flatMap(_.toIntOption) - private def readBreak(json: Value): PhpBreakStmt = - val num = readBreakContinueNum(json) - PhpBreakStmt(num, PhpAttributes(json)) - - private def readContinue(json: Value): PhpContinueStmt = - val num = readBreakContinueNum(json) - PhpContinueStmt(num, PhpAttributes(json)) - - private def readWhile(json: Value): PhpWhileStmt = - val cond = readExpr(json("cond")) - val stmts = json("stmts").arr.toList.map(readStmt) - PhpWhileStmt(cond, stmts, PhpAttributes(json)) - - private def readDo(json: Value): PhpDoStmt = - val cond = readExpr(json("cond")) - val stmts = json("stmts").arr.toList.map(readStmt) - PhpDoStmt(cond, stmts, PhpAttributes(json)) - - private def readFor(json: Value): PhpForStmt = - val inits = json("init").arr.map(readExpr).toList - val conditions = json("cond").arr.map(readExpr).toList - val loopExprs = json("loop").arr.map(readExpr).toList - val bodyStmts = json("stmts").arr.map(readStmt).toList - - PhpForStmt(inits, conditions, loopExprs, bodyStmts, PhpAttributes(json)) - - private def readIf(json: Value): PhpIfStmt = - val condition = readExpr(json("cond")) - val stmts = json("stmts").arr.map(readStmt).toList - val elseIfs = json("elseifs").arr.map(readElseIf).toList - val elseStmt = Option.when(!json("else").isNull)(readElse(json("else"))) - - PhpIfStmt(condition, stmts, elseIfs, elseStmt, PhpAttributes(json)) - - private def readSwitch(json: Value): PhpSwitchStmt = - val condition = readExpr(json("cond")) - val cases = json("cases").arr.map(readCase).toList - - PhpSwitchStmt(condition, cases, PhpAttributes(json)) - - private def readTry(json: Value): PhpTryStmt = - val stmts = json("stmts").arr.map(readStmt).toList - val catches = json("catches").arr.map(readCatch).toList - val finallyStmt = Option.unless(json("finally").isNull)(readFinally(json("finally"))) - - PhpTryStmt(stmts, catches, finallyStmt, PhpAttributes(json)) - - private def readThrow(json: Value): PhpThrowExpr = - val expr = readExpr(json("expr")) - - PhpThrowExpr(expr, PhpAttributes(json)) - - private def readList(json: Value): PhpListExpr = - val items = - json("items").arr.map(item => Option.unless(item.isNull)(readArrayItem(item))).toList - - PhpListExpr(items, PhpAttributes(json)) - - private def readNew(json: Value): PhpNewExpr = - val classNode = - if json("class")("nodeType").strOpt.contains("Stmt_Class") then - readClassLike(json("class"), ClassLikeTypes.Class) - else - readNameOrExpr(json, "class") - - val args = json("args").arr.map(readCallArg).toList - - PhpNewExpr(classNode, args, PhpAttributes(json)) - - private def readInclude(json: Value): PhpIncludeExpr = - val expr = readExpr(json("expr")) - val includeType = json("type").num.toInt match - case 1 => PhpIncludeType.Include - case 2 => PhpIncludeType.IncludeOnce - case 3 => PhpIncludeType.Require - case 4 => PhpIncludeType.RequireOnce - case other => - logger.debug(s"Unhandled include type: $other. Defaulting to regular include.") - PhpIncludeType.Include - - PhpIncludeExpr(expr, includeType, PhpAttributes(json)) - - private def readMatch(json: Value): PhpMatchExpr = - val condition = readExpr(json("cond")) - val matchArms = json("arms").arr.map(readMatchArm).toList - - PhpMatchExpr(condition, matchArms, PhpAttributes(json)) - - private def readMatchArm(json: Value): PhpMatchArm = - val conditions = json("conds") match - case ujson.Null => Nil - case conds => conds.arr.map(readExpr).toList - - val isDefault = json("conds").isNull - val body = readExpr(json("body")) - - PhpMatchArm(conditions, body, isDefault, PhpAttributes(json)) - - private def readYield(json: Value): PhpYieldExpr = - val key = Option.unless(json("key").isNull)(readExpr(json("key"))) - val value = Option.unless(json("value").isNull)(readExpr(json("value"))) - - PhpYieldExpr(key, value, PhpAttributes(json)) - - private def readYieldFrom(json: Value): PhpYieldFromExpr = - val expr = readExpr(json("expr")) + def isBinaryOpType(typeName: String): Boolean = + BinaryOpTypeMap.contains(typeName) + end PhpBinaryOp + final case class PhpUnaryOp(operator: String, expr: PhpExpr, attributes: PhpAttributes) + extends PhpExpr + object PhpUnaryOp: + val UnaryOpTypeMap: Map[String, String] = Map( + "Expr_BitwiseNot" -> Operators.not, + "Expr_BooleanNot" -> Operators.logicalNot, + "Expr_PostDec" -> Operators.postDecrement, + "Expr_PostInc" -> Operators.postIncrement, + "Expr_PreDec" -> Operators.preDecrement, + "Expr_PreInc" -> Operators.preIncrement, + "Expr_UnaryMinus" -> Operators.minus, + "Expr_UnaryPlus" -> Operators.plus + ) - PhpYieldFromExpr(expr, PhpAttributes(json)) + def isUnaryOpType(typeName: String): Boolean = + UnaryOpTypeMap.contains(typeName) + final case class PhpTernaryOp( + condition: PhpExpr, + thenExpr: Option[PhpExpr], + elseExpr: PhpExpr, + attributes: PhpAttributes + ) extends PhpExpr + + object PhpAssignment: + val AssignTypeMap: Map[String, String] = Map( + "Expr_Assign" -> Operators.assignment, + "Expr_AssignRef" -> Operators.assignment, + "Expr_AssignOp_BitwiseAnd" -> Operators.assignmentAnd, + "Expr_AssignOp_BitwiseOr" -> Operators.assignmentOr, + "Expr_AssignOp_BitwiseXor" -> Operators.assignmentXor, + "Expr_AssignOp_Coalesce" -> PhpOperators.assignmentCoalesceOp, + "Expr_AssignOp_Concat" -> PhpOperators.assignmentConcatOp, + "Expr_AssignOp_Div" -> Operators.assignmentDivision, + "Expr_AssignOp_Minus" -> Operators.assignmentMinus, + "Expr_AssignOp_Mod" -> Operators.assignmentModulo, + "Expr_AssignOp_Mul" -> Operators.assignmentMultiplication, + "Expr_AssignOp_Plus" -> Operators.assignmentPlus, + "Expr_AssignOp_Pow" -> Operators.assignmentExponentiation, + "Expr_AssignOp_ShiftLeft" -> Operators.assignmentShiftLeft, + "Expr_AssignOp_ShiftRight" -> Operators.assignmentArithmeticShiftRight + ) - private def readClosure(json: Value): PhpClosureExpr = - val params = json("params").arr.map(readParam).toList - val stmts = json("stmts").arr.map(readStmt).toList - val returnType = Option.unless(json("returnType").isNull)(readType(json("returnType"))) - val uses = json("uses").arr.map(readClosureUse).toList - val isStatic = json("static").bool - val isByRef = json("byRef").bool - val isArrowFunc = false + def isAssignType(typeName: String): Boolean = + AssignTypeMap.contains(typeName) + end PhpAssignment + final case class PhpAssignment( + assignOp: String, + target: PhpExpr, + source: PhpExpr, + isRefAssign: Boolean, + attributes: PhpAttributes + ) extends PhpExpr + + final case class PhpCast(typ: String, expr: PhpExpr, attributes: PhpAttributes) extends PhpExpr + object PhpCast: + val CastTypeMap: Map[String, String] = Map( + "Expr_Cast_Array" -> PhpDomainTypeConstants.array, + "Expr_Cast_Bool" -> PhpDomainTypeConstants.bool, + "Expr_Cast_Double" -> PhpDomainTypeConstants.double, + "Expr_Cast_Int" -> PhpDomainTypeConstants.int, + "Expr_Cast_Object" -> PhpDomainTypeConstants.obj, + "Expr_Cast_String" -> PhpDomainTypeConstants.string, + "Expr_Cast_Unset" -> PhpDomainTypeConstants.unset + ) - PhpClosureExpr( - params, - stmts, - returnType, - uses, - isStatic, - isByRef, - isArrowFunc, - PhpAttributes(json) - ) - end readClosure - - private def readClosureUse(json: Value): PhpClosureUse = - val variable = readVariable(json("var")) - val isByRef = json("byRef").bool - - PhpClosureUse(variable, isByRef, PhpAttributes(json)) - - private def readClassConstFetch(json: Value): PhpClassConstFetchExpr = - val classNameType = json("class")("nodeType").str - val className = - if classNameType.startsWith("Name") then - readName(json("class")) - else - readExpr(json("class")) - - val constantName = json("name") match - case str: Str => Some(PhpNameExpr(str.value, PhpAttributes(json))) - case obj: Obj if obj("nodeType").strOpt.contains("Expr_Error") => None - case obj: Obj => Some(readName(obj)) - case other => throw new NotImplementedError( - s"unexpected constant name '$other' of type ${other.getClass}" - ) - - PhpClassConstFetchExpr(className, constantName, PhpAttributes(json)) - - private def readConstFetch(json: Value): PhpConstFetchExpr = - val name = readName(json("name")) - - PhpConstFetchExpr(name, PhpAttributes(json)) - - private def readArray(json: Value): PhpArrayExpr = - val items = json("items").arr.map { item => - Option.unless(item.isNull)(readArrayItem(item)) - }.toList - PhpArrayExpr(items, PhpAttributes(json)) - - private def readArrayItem(json: Value): PhpArrayItem = - val key = Option.unless(json("key").isNull)(readExpr(json("key"))) - val value = readExpr(json("value")) - val byRef = json("byRef").bool - val unpack = json("byRef").bool - - PhpArrayItem(key, value, byRef, unpack, PhpAttributes(json)) - - private def readArrayDimFetch(json: Value): PhpArrayDimFetchExpr = - val variable = readExpr(json("var")) - val dimension = Option.unless(json("dim").isNull)(readExpr(json("dim"))) - - PhpArrayDimFetchExpr(variable, dimension, PhpAttributes(json)) - - private def readErrorSuppress(json: Value): PhpErrorSuppressExpr = - val expr = readExpr(json("expr")) - PhpErrorSuppressExpr(expr, PhpAttributes(json)) - - private def readInstanceOf(json: Value): PhpInstanceOfExpr = - val expr = readExpr(json("expr")) - val className = readNameOrExpr(json, "class") - - PhpInstanceOfExpr(expr, className, PhpAttributes(json)) - - private def readShellExec(json: Value): PhpShellExecExpr = - val parts = readEncapsed(json) - - PhpShellExecExpr(parts, PhpAttributes(json)) - - private def readArrowFunction(json: Value): PhpClosureExpr = - val params = json("params").arr.map(readParam).toList - val expr = readExpr(json("expr")) - val returnType = Option.unless(json("returnType").isNull)(readType(json("returnType"))) - val isStatic = json("static").bool - val returnByRef = json("byRef").bool - val uses = Nil // Not defined for arrow shorthand - val isArrowFunc = true - - // Introduce a return here to keep arrow functions consistent with regular closures while allowing easy code re-use. - val syntheticReturn = PhpReturnStmt(Some(expr), expr.attributes) - PhpClosureExpr( - params, - syntheticReturn :: Nil, - returnType, - uses, - isStatic, - returnByRef, - isArrowFunc, - PhpAttributes(json) - ) - end readArrowFunction - - private def readPropertyFetch( - json: Value, - isNullsafe: Boolean = false, - isStatic: Boolean = false - ): PhpPropertyFetchExpr = - val expr = - if json.obj.contains("var") then - readExpr(json("var")) - else - readNameOrExpr(json, "class") - - val name = readNameOrExpr(json, "name") - - PhpPropertyFetchExpr(expr, name, isNullsafe, isStatic, PhpAttributes(json)) - - private def readReturn(json: Value): PhpReturnStmt = - val expr = Option.unless(json("expr").isNull)(readExpr(json("expr"))) - - PhpReturnStmt(expr, PhpAttributes(json)) - - private def extendsForClassLike(json: Value): List[PhpNameExpr] = - json.obj - .get("extends") - .map { - case ujson.Null => Nil - case arr: ujson.Arr => arr.arr.map(readName).toList - case obj: ujson.Obj => readName(obj) :: Nil - case other => throw new NotImplementedError( - s"unexpected 'extends' entry '$other' of type ${other.getClass}" - ) - } - .getOrElse(Nil) - - private def readClassLike(json: Value, classLikeType: String): PhpClassLikeStmt = - val name = Option.unless(json("name").isNull)(readName(json("name"))) - val modifiers = PhpModifiers.getModifierSet(json) - - val extendsNames = extendsForClassLike(json) - - val implements = json.obj.get("implements").map(_.arr.toList).getOrElse(Nil).map(readName) - val stmts = json("stmts").arr.map(readStmt).toList - - val scalarType = - json.obj.get("scalarType").flatMap(typ => Option.unless(typ.isNull)(readName(typ))) - - val hasConstructor = classLikeType == ClassLikeTypes.Class - - val attributes = PhpAttributes(json) - - PhpClassLikeStmt( - name, - modifiers, - extendsNames, - implements, - stmts, - classLikeType, - scalarType, - hasConstructor, - attributes - ) - end readClassLike - - private def readEnumCase(json: Value): PhpEnumCaseStmt = - val name = readName(json("name")) - val expr = Option.unless(json("expr").isNull)(readExpr(json("expr"))) - - PhpEnumCaseStmt(name, expr, PhpAttributes(json)) - - private def readCatch(json: Value): PhpCatchStmt = - val types = json("types").arr.map(readName).toList - val variable = Option.unless(json("var").isNull)(readExpr(json("var"))) - val stmts = json("stmts").arr.map(readStmt).toList - - PhpCatchStmt(types, variable, stmts, PhpAttributes(json)) - - private def readFinally(json: Value): PhpFinallyStmt = - val stmts = json("stmts").arr.map(readStmt).toList - - PhpFinallyStmt(stmts, PhpAttributes(json)) - - private def readCase(json: Value): PhpCaseStmt = - val condition = Option.unless(json("cond").isNull)(readExpr(json("cond"))) - val stmts = json("stmts").arr.map(readStmt).toList - - PhpCaseStmt(condition, stmts, PhpAttributes(json)) - - private def readElseIf(json: Value): PhpElseIfStmt = - val condition = readExpr(json("cond")) - val stmts = json("stmts").arr.map(readStmt).toList - - PhpElseIfStmt(condition, stmts, PhpAttributes(json)) - - private def readElse(json: Value): PhpElseStmt = - val stmts = json("stmts").arr.map(readStmt).toList - - PhpElseStmt(stmts, PhpAttributes(json)) - - private def readEncapsed(json: Value): PhpEncapsed = - PhpEncapsed(json("parts").arr.map(readExpr).toSeq, PhpAttributes(json)) - - private def readMagicConst(json: Value): PhpConstFetchExpr = - val name = json("nodeType").str match - case "Scalar_MagicConst_Class" => "__CLASS__" - case "Scalar_MagicConst_Dir" => "__DIR__" - case "Scalar_MagicConst_File" => "__FILE__" - case "Scalar_MagicConst_Function" => "__FUNCTION__" - case "Scalar_MagicConst_Line" => "__LINE__" - case "Scalar_MagicConst_Method" => "__METHOD__" - case "Scalar_MagicConst_Namespace" => "__NAMESPACE__" - case "Scalar_MagicConst_Trait" => "__TRAIT__" - - val attributes = PhpAttributes(json) - - PhpConstFetchExpr(PhpNameExpr(name, attributes), attributes) - - private def readExpr(json: Value): PhpExpr = - json("nodeType").str match - case "Scalar_String" => readString(json) - case "Scalar_DNumber" => PhpFloat(json("value").toString, PhpAttributes(json)) - case "Scalar_Float" => PhpFloat(json("value").toString, PhpAttributes(json)) - case "Scalar_LNumber" => PhpInt(json("value").toString, PhpAttributes(json)) - case "Scalar_Int" => PhpInt(json("value").toString, PhpAttributes(json)) - case "Scalar_Encapsed" => readEncapsed(json) - case "Scalar_InterpolatedString" => readEncapsed(json) - case "Scalar_EncapsedStringPart" => readString(json) - case "InterpolatedStringPart" => readString(json) - - case typ if typ.startsWith("Scalar_MagicConst") => readMagicConst(json) - - case "Expr_FuncCall" => readCall(json) - case "Expr_MethodCall" => readCall(json) - case "Expr_NullsafeMethodCall" => readCall(json) - case "Expr_StaticCall" => readCall(json) - - case "Expr_Clone" => readClone(json) - case "Expr_Empty" => readEmpty(json) - case "Expr_Eval" => readEval(json) - case "Expr_Exit" => readExit(json) - case "Expr_Variable" => readVariable(json) - case "Expr_Isset" => readIsset(json) - case "Expr_Print" => readPrint(json) - case "Expr_Ternary" => readTernaryOp(json) - case "Expr_Throw" => readThrow(json) - case "Expr_List" => readList(json) - case "Expr_New" => readNew(json) - case "Expr_Include" => readInclude(json) - case "Expr_Match" => readMatch(json) - case "Expr_Yield" => readYield(json) - case "Expr_YieldFrom" => readYieldFrom(json) - case "Expr_Closure" => readClosure(json) - - case "Expr_ClassConstFetch" => readClassConstFetch(json) - case "Expr_ConstFetch" => readConstFetch(json) - - case "Expr_Array" => readArray(json) - case "Expr_ArrayDimFetch" => readArrayDimFetch(json) - case "Expr_ErrorSuppress" => readErrorSuppress(json) - case "Expr_Instanceof" => readInstanceOf(json) - case "Expr_ShellExec" => readShellExec(json) - case "Expr_ArrowFunction" => readArrowFunction(json) - - case "Expr_PropertyFetch" => readPropertyFetch(json) - case "Expr_NullsafePropertyFetch" => readPropertyFetch(json, isNullsafe = true) - case "Expr_StaticPropertyFetch" => readPropertyFetch(json, isStatic = true) - - case typ if isUnaryOpType(typ) => readUnaryOp(json) - case typ if isBinaryOpType(typ) => readBinaryOp(json) - case typ if isAssignType(typ) => readAssign(json) - case typ if isCastType(typ) => readCast(json) - - case unhandled => - logger.debug(s"Found unhandled expr type: $unhandled") - ??? - - private def readClone(json: Value): PhpCloneExpr = - val expr = readExpr(json("expr")) - PhpCloneExpr(expr, PhpAttributes(json)) - - private def readEmpty(json: Value): PhpEmptyExpr = - val expr = readExpr(json("expr")) - PhpEmptyExpr(expr, PhpAttributes(json)) - - private def readEval(json: Value): PhpEvalExpr = - val expr = readExpr(json("expr")) - PhpEvalExpr(expr, PhpAttributes(json)) - - private def readExit(json: Value): PhpExitExpr = - val expr = Option.unless(json("expr").isNull)(readExpr(json("expr"))) - PhpExitExpr(expr, PhpAttributes(json)) - - private def readVariable(json: Value): PhpVariable = - if !json.obj.contains("name") then - logger.debug(s"Variable did not contain name: $json") - val varAttrs = PhpAttributes(json) - val name = json("name") match - case Str(value) => readName(value).copy(attributes = varAttrs) - case Obj(_) => readNameOrExpr(json, "name") - case value => readExpr(value) - PhpVariable(name, varAttrs) - - private def readIsset(json: Value): PhpIsset = - val vars = json("vars").arr.map(readExpr).toList - PhpIsset(vars, PhpAttributes(json)) - - private def readPrint(json: Value): PhpPrint = - val expr = readExpr(json("expr")) - PhpPrint(expr, PhpAttributes(json)) - - private def readTernaryOp(json: Value): PhpTernaryOp = - val condition = readExpr(json("cond")) - val maybeThenExpr = Option.unless(json("if").isNull)(readExpr(json("if"))) - val elseExpr = readExpr(json("else")) - - PhpTernaryOp(condition, maybeThenExpr, elseExpr, PhpAttributes(json)) - - private def readNameOrExpr(json: Value, fieldName: String): PhpExpr = - val field = json(fieldName) - if field("nodeType").str.startsWith("Name") then - readName(field) - else if field("nodeType").str == "Identifier" then - readName(field) - else if field("nodeType").str == "VarLikeIdentifier" then - readVariable(field) + def isCastType(typeName: String): Boolean = + CastTypeMap.contains(typeName) + + final case class PhpIsset(vars: Seq[PhpExpr], attributes: PhpAttributes) extends PhpExpr + final case class PhpPrint(expr: PhpExpr, attributes: PhpAttributes) extends PhpExpr + + sealed trait PhpScalar extends PhpExpr + sealed abstract class PhpSimpleScalar(val typeFullName: String) extends PhpScalar: + def value: String + def attributes: PhpAttributes + + final case class PhpString(val value: String, val attributes: PhpAttributes) + extends PhpSimpleScalar(TypeConstants.String) + object PhpString: + def withQuotes(value: String, attributes: PhpAttributes): PhpString = + PhpString(s"\"${escapeString(value)}\"", attributes) + + final case class PhpInt(val value: String, val attributes: PhpAttributes) + extends PhpSimpleScalar(TypeConstants.Int) + + final case class PhpFloat(val value: String, val attributes: PhpAttributes) + extends PhpSimpleScalar(TypeConstants.Float) + + final case class PhpEncapsed(parts: Seq[PhpExpr], attributes: PhpAttributes) extends PhpScalar + + final case class PhpThrowExpr(expr: PhpExpr, attributes: PhpAttributes) extends PhpExpr + final case class PhpListExpr(items: List[Option[PhpArrayItem]], attributes: PhpAttributes) + extends PhpExpr + + final case class PhpClassConstFetchExpr( + className: PhpExpr, + constantName: Option[PhpNameExpr], + attributes: PhpAttributes + ) extends PhpExpr + + final case class PhpConstFetchExpr(name: PhpNameExpr, attributes: PhpAttributes) extends PhpExpr + + final case class PhpArrayExpr(items: List[Option[PhpArrayItem]], attributes: PhpAttributes) + extends PhpExpr + final case class PhpArrayItem( + key: Option[PhpExpr], + value: PhpExpr, + byRef: Boolean, + unpack: Boolean, + attributes: PhpAttributes + ) extends PhpExpr + final case class PhpArrayDimFetchExpr( + variable: PhpExpr, + dimension: Option[PhpExpr], + attributes: PhpAttributes + ) extends PhpExpr + + final case class PhpErrorSuppressExpr(expr: PhpExpr, attributes: PhpAttributes) extends PhpExpr + + final case class PhpInstanceOfExpr(expr: PhpExpr, className: PhpExpr, attributes: PhpAttributes) + extends PhpExpr + + final case class PhpShellExecExpr(parts: PhpEncapsed, attributes: PhpAttributes) extends PhpExpr + + final case class PhpPropertyFetchExpr( + expr: PhpExpr, + name: PhpExpr, + isNullsafe: Boolean, + isStatic: Boolean, + attributes: PhpAttributes + ) extends PhpExpr + + final case class PhpMatchExpr( + condition: PhpExpr, + matchArms: List[PhpMatchArm], + attributes: PhpAttributes + ) extends PhpExpr + + final case class PhpMatchArm( + conditions: List[PhpExpr], + body: PhpExpr, + isDefault: Boolean, + attributes: PhpAttributes + ) extends PhpExpr + + final case class PhpYieldExpr( + key: Option[PhpExpr], + value: Option[PhpExpr], + attributes: PhpAttributes + ) extends PhpExpr + final case class PhpYieldFromExpr(expr: PhpExpr, attributes: PhpAttributes) extends PhpExpr + + final case class PhpClosureExpr( + params: List[PhpParam], + stmts: List[PhpStmt], + returnType: Option[PhpNameExpr], + uses: List[PhpClosureUse], + isStatic: Boolean, + returnByRef: Boolean, + isArrowFunc: Boolean, + attributes: PhpAttributes + ) extends PhpExpr + final case class PhpClosureUse(variable: PhpExpr, byRef: Boolean, attributes: PhpAttributes) + extends PhpExpr + + private def escapeString(value: String): String = + value + .replace("\\", "\\\\") + .replace("\n", "\\n") + .replace("\b", "\\b") + .replace("\r", "\\r") + .replace("\t", "\\t") + .replace("\'", "\\'") + .replace("\f", "\\f") + .replace("\"", "\\\"") + + private def readFile(json: Value): PhpFile = + json match + case arr: Arr => + val children = arr.value.map(readStmt).toList + PhpFile(children) + case unhandled => + logger.debug( + s"Found unhandled type in readFile: ${unhandled.getClass} with value $unhandled" + ) + ??? + + private def readStmt(json: Value): PhpStmt = + json("nodeType").str match + case "Stmt_Echo" => + val values = json("exprs").arr.map(readExpr).toSeq + PhpEchoStmt(values, PhpAttributes(json)) + case "Stmt_Expression" => readExpr(json("expr")) + case "Stmt_Function" => readFunction(json) + case "Stmt_InlineHTML" => readInlineHtml(json) + case "Stmt_Break" => readBreak(json) + case "Stmt_Continue" => readContinue(json) + case "Stmt_While" => readWhile(json) + case "Stmt_Do" => readDo(json) + case "Stmt_For" => readFor(json) + case "Stmt_If" => readIf(json) + case "Stmt_Switch" => readSwitch(json) + case "Stmt_TryCatch" => readTry(json) + case "Stmt_Throw" => readThrow(json) + case "Stmt_Return" => readReturn(json) + case "Stmt_Class" => readClassLike(json, ClassLikeTypes.Class) + case "Stmt_Interface" => readClassLike(json, ClassLikeTypes.Interface) + case "Stmt_Trait" => readClassLike(json, ClassLikeTypes.Trait) + case "Stmt_Enum" => readClassLike(json, ClassLikeTypes.Enum) + case "Stmt_EnumCase" => readEnumCase(json) + case "Stmt_ClassMethod" => readClassMethod(json) + case "Stmt_Property" => readProperty(json) + case "Stmt_ClassConst" => readConst(json) + case "Stmt_Const" => readConst(json) + case "Stmt_Goto" => readGoto(json) + case "Stmt_Label" => readLabel(json) + case "Stmt_HaltCompiler" => readHaltCompiler(json) + case "Stmt_Namespace" => readNamespace(json) + case "Stmt_Nop" => NopStmt(PhpAttributes(json)) + case "Stmt_Declare" => readDeclare(json) + case "Stmt_Unset" => readUnset(json) + case "Stmt_Static" => readStatic(json) + case "Stmt_Global" => readGlobal(json) + case "Stmt_Use" => readUse(json) + case "Stmt_GroupUse" => readGroupUse(json) + case "Stmt_Foreach" => readForeach(json) + case "Stmt_TraitUse" => readTraitUse(json) + case "Stmt_Block" => NopStmt(PhpAttributes(json)) + case unhandled => NopStmt(PhpAttributes(json)) + + private def readString(json: Value): PhpString = + PhpString.withQuotes(json("value").str, PhpAttributes(json)) + + private def readInlineHtml(json: Value): PhpStmt = + val value = readString(json) + PhpEchoStmt(List(value), value.attributes) + + private def readBreakContinueNum(json: Value): Option[Int] = + Option.unless(json("num").isNull)(json("num")("value").toString).flatMap(_.toIntOption) + private def readBreak(json: Value): PhpBreakStmt = + val num = readBreakContinueNum(json) + PhpBreakStmt(num, PhpAttributes(json)) + + private def readContinue(json: Value): PhpContinueStmt = + val num = readBreakContinueNum(json) + PhpContinueStmt(num, PhpAttributes(json)) + + private def readWhile(json: Value): PhpWhileStmt = + val cond = readExpr(json("cond")) + val stmts = json("stmts").arr.toList.map(readStmt) + PhpWhileStmt(cond, stmts, PhpAttributes(json)) + + private def readDo(json: Value): PhpDoStmt = + val cond = readExpr(json("cond")) + val stmts = json("stmts").arr.toList.map(readStmt) + PhpDoStmt(cond, stmts, PhpAttributes(json)) + + private def readFor(json: Value): PhpForStmt = + val inits = json("init").arr.map(readExpr).toList + val conditions = json("cond").arr.map(readExpr).toList + val loopExprs = json("loop").arr.map(readExpr).toList + val bodyStmts = json("stmts").arr.map(readStmt).toList + + PhpForStmt(inits, conditions, loopExprs, bodyStmts, PhpAttributes(json)) + + private def readIf(json: Value): PhpIfStmt = + val condition = readExpr(json("cond")) + val stmts = json("stmts").arr.map(readStmt).toList + val elseIfs = json("elseifs").arr.map(readElseIf).toList + val elseStmt = Option.when(!json("else").isNull)(readElse(json("else"))) + + PhpIfStmt(condition, stmts, elseIfs, elseStmt, PhpAttributes(json)) + + private def readSwitch(json: Value): PhpSwitchStmt = + val condition = readExpr(json("cond")) + val cases = json("cases").arr.map(readCase).toList + + PhpSwitchStmt(condition, cases, PhpAttributes(json)) + + private def readTry(json: Value): PhpTryStmt = + val stmts = json("stmts").arr.map(readStmt).toList + val catches = json("catches").arr.map(readCatch).toList + val finallyStmt = Option.unless(json("finally").isNull)(readFinally(json("finally"))) + + PhpTryStmt(stmts, catches, finallyStmt, PhpAttributes(json)) + + private def readThrow(json: Value): PhpThrowExpr = + val expr = readExpr(json("expr")) + + PhpThrowExpr(expr, PhpAttributes(json)) + + private def readList(json: Value): PhpListExpr = + val items = + json("items").arr.map(item => Option.unless(item.isNull)(readArrayItem(item))).toList + + PhpListExpr(items, PhpAttributes(json)) + + private def readNew(json: Value): PhpNewExpr = + val classNode = + if json("class")("nodeType").strOpt.contains("Stmt_Class") then + readClassLike(json("class"), ClassLikeTypes.Class) else - readExpr(field) - - private def readCall(json: Value): PhpCallExpr = - val jsonMap = json.obj - val nodeType = json("nodeType").str - val args = json("args").arr.map(readCallArg).toSeq - - val target = - jsonMap.get("var").map(readExpr).orElse(jsonMap.get("class").map(_ => - readNameOrExpr(jsonMap, "class") - )) - - val methodName = readNameOrExpr(json, "name") - - val isNullSafe = nodeType == "Expr_NullsafeMethodCall" - val isStatic = nodeType == "Expr_StaticCall" - - PhpCallExpr(target, methodName, args, isNullSafe, isStatic, PhpAttributes(json)) - - private def readFunction(json: Value): PhpMethodDecl = - val returnByRef = json("byRef").bool - val name = readName(json("name")) - val params = json("params").arr.map(readParam).toList - val returnType = Option.unless(json("returnType").isNull)(readType(json("returnType"))) - val stmts = json("stmts").arr.map(readStmt).toList - // Only class methods have modifiers - val modifiers = Nil - val namespacedName = - Option.unless(json("namespacedName").isNull)(readName(json("namespacedName"))) - val isClassMethod = false - - PhpMethodDecl( - name, - params, - modifiers, - returnType, - stmts, - returnByRef, - namespacedName, - isClassMethod, - PhpAttributes(json) - ) - end readFunction - - private def readClassMethod(json: Value): PhpMethodDecl = - val modifiers = PhpModifiers.getModifierSet(json) - val returnByRef = json("byRef").bool - val name = readName(json("name")) - val params = json("params").arr.map(readParam).toList - val returnType = Option.unless(json("returnType").isNull)(readType(json("returnType"))) - val stmts = - if json("stmts").isNull then - Nil - else - json("stmts").arr.map(readStmt).toList - - val namespacedName = None // only defined for functions - val isClassMethod = true - - PhpMethodDecl( - name, - params, - modifiers, - returnType, - stmts, - returnByRef, - namespacedName, - isClassMethod, - PhpAttributes(json) - ) - end readClassMethod + readNameOrExpr(json, "class") + + val args = json("args").arr.map(readCallArg).toList + + PhpNewExpr(classNode, args, PhpAttributes(json)) + + private def readInclude(json: Value): PhpIncludeExpr = + val expr = readExpr(json("expr")) + val includeType = json("type").num.toInt match + case 1 => PhpIncludeType.Include + case 2 => PhpIncludeType.IncludeOnce + case 3 => PhpIncludeType.Require + case 4 => PhpIncludeType.RequireOnce + case other => + logger.debug(s"Unhandled include type: $other. Defaulting to regular include.") + PhpIncludeType.Include + + PhpIncludeExpr(expr, includeType, PhpAttributes(json)) + + private def readMatch(json: Value): PhpMatchExpr = + val condition = readExpr(json("cond")) + val matchArms = json("arms").arr.map(readMatchArm).toList + + PhpMatchExpr(condition, matchArms, PhpAttributes(json)) + + private def readMatchArm(json: Value): PhpMatchArm = + val conditions = json("conds") match + case ujson.Null => Nil + case conds => conds.arr.map(readExpr).toList + + val isDefault = json("conds").isNull + val body = readExpr(json("body")) + + PhpMatchArm(conditions, body, isDefault, PhpAttributes(json)) + + private def readYield(json: Value): PhpYieldExpr = + val key = Option.unless(json("key").isNull)(readExpr(json("key"))) + val value = Option.unless(json("value").isNull)(readExpr(json("value"))) + + PhpYieldExpr(key, value, PhpAttributes(json)) + + private def readYieldFrom(json: Value): PhpYieldFromExpr = + val expr = readExpr(json("expr")) + + PhpYieldFromExpr(expr, PhpAttributes(json)) + + private def readClosure(json: Value): PhpClosureExpr = + val params = json("params").arr.map(readParam).toList + val stmts = json("stmts").arr.map(readStmt).toList + val returnType = Option.unless(json("returnType").isNull)(readType(json("returnType"))) + val uses = json("uses").arr.map(readClosureUse).toList + val isStatic = json("static").bool + val isByRef = json("byRef").bool + val isArrowFunc = false + + PhpClosureExpr( + params, + stmts, + returnType, + uses, + isStatic, + isByRef, + isArrowFunc, + PhpAttributes(json) + ) + end readClosure - private def readProperty(json: Value): PhpPropertyStmt = - val modifiers = PhpModifiers.getModifierSet(json) - val variables = json("props").arr.map(readPropertyValue).toList - val typeName = Option.unless(json("type").isNull)(readType(json("type"))) + private def readClosureUse(json: Value): PhpClosureUse = + val variable = readVariable(json("var")) + val isByRef = json("byRef").bool - PhpPropertyStmt(modifiers, variables, typeName, PhpAttributes(json)) + PhpClosureUse(variable, isByRef, PhpAttributes(json)) - private def readPropertyValue(json: Value): PhpPropertyValue = - val name = readName(json("name")) - val defaultValue = Option.unless(json("default").isNull)(readExpr(json("default"))) + private def readClassConstFetch(json: Value): PhpClassConstFetchExpr = + val classNameType = json("class")("nodeType").str + val className = + if classNameType.startsWith("Name") then + readName(json("class")) + else + readExpr(json("class")) + + val constantName = json("name") match + case str: Str => Some(PhpNameExpr(str.value, PhpAttributes(json))) + case obj: Obj if obj("nodeType").strOpt.contains("Expr_Error") => None + case obj: Obj => Some(readName(obj)) + case other => throw new NotImplementedError( + s"unexpected constant name '$other' of type ${other.getClass}" + ) + + PhpClassConstFetchExpr(className, constantName, PhpAttributes(json)) + + private def readConstFetch(json: Value): PhpConstFetchExpr = + val name = readName(json("name")) + + PhpConstFetchExpr(name, PhpAttributes(json)) + + private def readArray(json: Value): PhpArrayExpr = + val items = json("items").arr.map { item => + Option.unless(item.isNull)(readArrayItem(item)) + }.toList + PhpArrayExpr(items, PhpAttributes(json)) + + private def readArrayItem(json: Value): PhpArrayItem = + val key = Option.unless(json("key").isNull)(readExpr(json("key"))) + val value = readExpr(json("value")) + val byRef = json("byRef").bool + val unpack = json("byRef").bool + + PhpArrayItem(key, value, byRef, unpack, PhpAttributes(json)) + + private def readArrayDimFetch(json: Value): PhpArrayDimFetchExpr = + val variable = readExpr(json("var")) + val dimension = Option.unless(json("dim").isNull)(readExpr(json("dim"))) + + PhpArrayDimFetchExpr(variable, dimension, PhpAttributes(json)) + + private def readErrorSuppress(json: Value): PhpErrorSuppressExpr = + val expr = readExpr(json("expr")) + PhpErrorSuppressExpr(expr, PhpAttributes(json)) + + private def readInstanceOf(json: Value): PhpInstanceOfExpr = + val expr = readExpr(json("expr")) + val className = readNameOrExpr(json, "class") + + PhpInstanceOfExpr(expr, className, PhpAttributes(json)) + + private def readShellExec(json: Value): PhpShellExecExpr = + val parts = readEncapsed(json) + + PhpShellExecExpr(parts, PhpAttributes(json)) + + private def readArrowFunction(json: Value): PhpClosureExpr = + val params = json("params").arr.map(readParam).toList + val expr = readExpr(json("expr")) + val returnType = Option.unless(json("returnType").isNull)(readType(json("returnType"))) + val isStatic = json("static").bool + val returnByRef = json("byRef").bool + val uses = Nil // Not defined for arrow shorthand + val isArrowFunc = true + + // Introduce a return here to keep arrow functions consistent with regular closures while allowing easy code re-use. + val syntheticReturn = PhpReturnStmt(Some(expr), expr.attributes) + PhpClosureExpr( + params, + syntheticReturn :: Nil, + returnType, + uses, + isStatic, + returnByRef, + isArrowFunc, + PhpAttributes(json) + ) + end readArrowFunction + + private def readPropertyFetch( + json: Value, + isNullsafe: Boolean = false, + isStatic: Boolean = false + ): PhpPropertyFetchExpr = + val expr = + if json.obj.contains("var") then + readExpr(json("var")) + else + readNameOrExpr(json, "class") - PhpPropertyValue(name, defaultValue, PhpAttributes(json)) + val name = readNameOrExpr(json, "name") - private def readConst(json: Value): PhpConstStmt = - val modifiers = PhpModifiers.getModifierSet(json) + PhpPropertyFetchExpr(expr, name, isNullsafe, isStatic, PhpAttributes(json)) - val constDeclarations = json("consts").arr.map(readConstDeclaration).toList + private def readReturn(json: Value): PhpReturnStmt = + val expr = Option.unless(json("expr").isNull)(readExpr(json("expr"))) - PhpConstStmt(modifiers, constDeclarations, PhpAttributes(json)) + PhpReturnStmt(expr, PhpAttributes(json)) - private def readGoto(json: Value): PhpGotoStmt = - val name = readName(json("name")) - PhpGotoStmt(name, PhpAttributes(json)) + private def extendsForClassLike(json: Value): List[PhpNameExpr] = + json.obj + .get("extends") + .map { + case ujson.Null => Nil + case arr: ujson.Arr => arr.arr.map(readName).toList + case obj: ujson.Obj => readName(obj) :: Nil + case other => throw new NotImplementedError( + s"unexpected 'extends' entry '$other' of type ${other.getClass}" + ) + } + .getOrElse(Nil) - private def readLabel(json: Value): PhpLabelStmt = - val name = readName(json("name")) - PhpLabelStmt(name, PhpAttributes(json)) + private def readClassLike(json: Value, classLikeType: String): PhpClassLikeStmt = + val name = Option.unless(json("name").isNull)(readName(json("name"))) + val modifiers = PhpModifiers.getModifierSet(json) - private def readHaltCompiler(json: Value): PhpHaltCompilerStmt = - // Ignore the remaining text here since it can get quite large (common use case is to separate code from data blob) - PhpHaltCompilerStmt(PhpAttributes(json)) + val extendsNames = extendsForClassLike(json) - private def readNamespace(json: Value): PhpNamespaceStmt = - val name = Option.unless(json("name").isNull)(readName(json("name"))) + val implements = json.obj.get("implements").map(_.arr.toList).getOrElse(Nil).map(readName) + val stmts = json("stmts").arr.map(readStmt).toList - val stmts = json("stmts") match - case ujson.Null => Nil - case stmts: Arr => stmts.arr.map(readStmt).toList - case unhandled => - logger.debug(s"Unhandled namespace stmts type $unhandled") - ??? + val scalarType = + json.obj.get("scalarType").flatMap(typ => Option.unless(typ.isNull)(readName(typ))) - PhpNamespaceStmt(name, stmts, PhpAttributes(json)) + val hasConstructor = classLikeType == ClassLikeTypes.Class - private def readDeclare(json: Value): PhpDeclareStmt = - val declares = json("declares").arr.map(readDeclareItem).toList - val stmts = Option.unless(json("stmts").isNull)(json("stmts").arr.map(readStmt).toList) + val attributes = PhpAttributes(json) - PhpDeclareStmt(declares, stmts, PhpAttributes(json)) + PhpClassLikeStmt( + name, + modifiers, + extendsNames, + implements, + stmts, + classLikeType, + scalarType, + hasConstructor, + attributes + ) + end readClassLike + + private def readEnumCase(json: Value): PhpEnumCaseStmt = + val name = readName(json("name")) + val expr = Option.unless(json("expr").isNull)(readExpr(json("expr"))) + + PhpEnumCaseStmt(name, expr, PhpAttributes(json)) + + private def readCatch(json: Value): PhpCatchStmt = + val types = json("types").arr.map(readName).toList + val variable = Option.unless(json("var").isNull)(readExpr(json("var"))) + val stmts = json("stmts").arr.map(readStmt).toList + + PhpCatchStmt(types, variable, stmts, PhpAttributes(json)) + + private def readFinally(json: Value): PhpFinallyStmt = + val stmts = json("stmts").arr.map(readStmt).toList + + PhpFinallyStmt(stmts, PhpAttributes(json)) + + private def readCase(json: Value): PhpCaseStmt = + val condition = Option.unless(json("cond").isNull)(readExpr(json("cond"))) + val stmts = json("stmts").arr.map(readStmt).toList + + PhpCaseStmt(condition, stmts, PhpAttributes(json)) + + private def readElseIf(json: Value): PhpElseIfStmt = + val condition = readExpr(json("cond")) + val stmts = json("stmts").arr.map(readStmt).toList + + PhpElseIfStmt(condition, stmts, PhpAttributes(json)) + + private def readElse(json: Value): PhpElseStmt = + val stmts = json("stmts").arr.map(readStmt).toList + + PhpElseStmt(stmts, PhpAttributes(json)) + + private def readEncapsed(json: Value): PhpEncapsed = + PhpEncapsed(json("parts").arr.map(readExpr).toSeq, PhpAttributes(json)) + + private def readMagicConst(json: Value): PhpConstFetchExpr = + val name = json("nodeType").str match + case "Scalar_MagicConst_Class" => "__CLASS__" + case "Scalar_MagicConst_Dir" => "__DIR__" + case "Scalar_MagicConst_File" => "__FILE__" + case "Scalar_MagicConst_Function" => "__FUNCTION__" + case "Scalar_MagicConst_Line" => "__LINE__" + case "Scalar_MagicConst_Method" => "__METHOD__" + case "Scalar_MagicConst_Namespace" => "__NAMESPACE__" + case "Scalar_MagicConst_Trait" => "__TRAIT__" + + val attributes = PhpAttributes(json) + + PhpConstFetchExpr(PhpNameExpr(name, attributes), attributes) + + private def readExpr(json: Value): PhpExpr = + json("nodeType").str match + case "Scalar_String" => readString(json) + case "Scalar_DNumber" => PhpFloat(json("value").toString, PhpAttributes(json)) + case "Scalar_Float" => PhpFloat(json("value").toString, PhpAttributes(json)) + case "Scalar_LNumber" => PhpInt(json("value").toString, PhpAttributes(json)) + case "Scalar_Int" => PhpInt(json("value").toString, PhpAttributes(json)) + case "Scalar_Encapsed" => readEncapsed(json) + case "Scalar_InterpolatedString" => readEncapsed(json) + case "Scalar_EncapsedStringPart" => readString(json) + case "InterpolatedStringPart" => readString(json) + + case typ if typ.startsWith("Scalar_MagicConst") => readMagicConst(json) + + case "Expr_FuncCall" => readCall(json) + case "Expr_MethodCall" => readCall(json) + case "Expr_NullsafeMethodCall" => readCall(json) + case "Expr_StaticCall" => readCall(json) + + case "Expr_Clone" => readClone(json) + case "Expr_Empty" => readEmpty(json) + case "Expr_Eval" => readEval(json) + case "Expr_Exit" => readExit(json) + case "Expr_Variable" => readVariable(json) + case "Expr_Isset" => readIsset(json) + case "Expr_Print" => readPrint(json) + case "Expr_Ternary" => readTernaryOp(json) + case "Expr_Throw" => readThrow(json) + case "Expr_List" => readList(json) + case "Expr_New" => readNew(json) + case "Expr_Include" => readInclude(json) + case "Expr_Match" => readMatch(json) + case "Expr_Yield" => readYield(json) + case "Expr_YieldFrom" => readYieldFrom(json) + case "Expr_Closure" => readClosure(json) + + case "Expr_ClassConstFetch" => readClassConstFetch(json) + case "Expr_ConstFetch" => readConstFetch(json) + + case "Expr_Array" => readArray(json) + case "Expr_ArrayDimFetch" => readArrayDimFetch(json) + case "Expr_ErrorSuppress" => readErrorSuppress(json) + case "Expr_Instanceof" => readInstanceOf(json) + case "Expr_ShellExec" => readShellExec(json) + case "Expr_ArrowFunction" => readArrowFunction(json) + + case "Expr_PropertyFetch" => readPropertyFetch(json) + case "Expr_NullsafePropertyFetch" => readPropertyFetch(json, isNullsafe = true) + case "Expr_StaticPropertyFetch" => readPropertyFetch(json, isStatic = true) + + case typ if isUnaryOpType(typ) => readUnaryOp(json) + case typ if isBinaryOpType(typ) => readBinaryOp(json) + case typ if isAssignType(typ) => readAssign(json) + case typ if isCastType(typ) => readCast(json) + + case unhandled => + logger.debug(s"Found unhandled expr type: $unhandled") + ??? + + private def readClone(json: Value): PhpCloneExpr = + val expr = readExpr(json("expr")) + PhpCloneExpr(expr, PhpAttributes(json)) + + private def readEmpty(json: Value): PhpEmptyExpr = + val expr = readExpr(json("expr")) + PhpEmptyExpr(expr, PhpAttributes(json)) + + private def readEval(json: Value): PhpEvalExpr = + val expr = readExpr(json("expr")) + PhpEvalExpr(expr, PhpAttributes(json)) + + private def readExit(json: Value): PhpExitExpr = + val expr = Option.unless(json("expr").isNull)(readExpr(json("expr"))) + PhpExitExpr(expr, PhpAttributes(json)) + + private def readVariable(json: Value): PhpVariable = + if !json.obj.contains("name") then + logger.debug(s"Variable did not contain name: $json") + val varAttrs = PhpAttributes(json) + val name = json("name") match + case Str(value) => readName(value).copy(attributes = varAttrs) + case Obj(_) => readNameOrExpr(json, "name") + case value => readExpr(value) + PhpVariable(name, varAttrs) + + private def readIsset(json: Value): PhpIsset = + val vars = json("vars").arr.map(readExpr).toList + PhpIsset(vars, PhpAttributes(json)) + + private def readPrint(json: Value): PhpPrint = + val expr = readExpr(json("expr")) + PhpPrint(expr, PhpAttributes(json)) + + private def readTernaryOp(json: Value): PhpTernaryOp = + val condition = readExpr(json("cond")) + val maybeThenExpr = Option.unless(json("if").isNull)(readExpr(json("if"))) + val elseExpr = readExpr(json("else")) + + PhpTernaryOp(condition, maybeThenExpr, elseExpr, PhpAttributes(json)) + + private def readNameOrExpr(json: Value, fieldName: String): PhpExpr = + val field = json(fieldName) + if field("nodeType").str.startsWith("Name") then + readName(field) + else if field("nodeType").str == "Identifier" then + readName(field) + else if field("nodeType").str == "VarLikeIdentifier" then + readVariable(field) + else + readExpr(field) + + private def readCall(json: Value): PhpCallExpr = + val jsonMap = json.obj + val nodeType = json("nodeType").str + val args = json("args").arr.map(readCallArg).toSeq + + val target = + jsonMap.get("var").map(readExpr).orElse(jsonMap.get("class").map(_ => + readNameOrExpr(jsonMap, "class") + )) + + val methodName = readNameOrExpr(json, "name") + + val isNullSafe = nodeType == "Expr_NullsafeMethodCall" + val isStatic = nodeType == "Expr_StaticCall" + + PhpCallExpr(target, methodName, args, isNullSafe, isStatic, PhpAttributes(json)) + + private def readFunction(json: Value): PhpMethodDecl = + val returnByRef = json("byRef").bool + val name = readName(json("name")) + val params = json("params").arr.map(readParam).toList + val returnType = Option.unless(json("returnType").isNull)(readType(json("returnType"))) + val stmts = json("stmts").arr.map(readStmt).toList + // Only class methods have modifiers + val modifiers = Nil + val namespacedName = + Option.unless(json("namespacedName").isNull)(readName(json("namespacedName"))) + val isClassMethod = false + + PhpMethodDecl( + name, + params, + modifiers, + returnType, + stmts, + returnByRef, + namespacedName, + isClassMethod, + PhpAttributes(json) + ) + end readFunction + + private def readClassMethod(json: Value): PhpMethodDecl = + val modifiers = PhpModifiers.getModifierSet(json) + val returnByRef = json("byRef").bool + val name = readName(json("name")) + val params = json("params").arr.map(readParam).toList + val returnType = Option.unless(json("returnType").isNull)(readType(json("returnType"))) + val stmts = + if json("stmts").isNull then + Nil + else + json("stmts").arr.map(readStmt).toList + + val namespacedName = None // only defined for functions + val isClassMethod = true + + PhpMethodDecl( + name, + params, + modifiers, + returnType, + stmts, + returnByRef, + namespacedName, + isClassMethod, + PhpAttributes(json) + ) + end readClassMethod - private def readUnset(json: Value): PhpUnsetStmt = - val vars = json("vars").arr.map(readExpr).toList + private def readProperty(json: Value): PhpPropertyStmt = + val modifiers = PhpModifiers.getModifierSet(json) + val variables = json("props").arr.map(readPropertyValue).toList + val typeName = Option.unless(json("type").isNull)(readType(json("type"))) - PhpUnsetStmt(vars, PhpAttributes(json)) + PhpPropertyStmt(modifiers, variables, typeName, PhpAttributes(json)) - private def readStatic(json: Value): PhpStaticStmt = - val vars = json("vars").arr.map(readStaticVar).toList + private def readPropertyValue(json: Value): PhpPropertyValue = + val name = readName(json("name")) + val defaultValue = Option.unless(json("default").isNull)(readExpr(json("default"))) - PhpStaticStmt(vars, PhpAttributes(json)) + PhpPropertyValue(name, defaultValue, PhpAttributes(json)) - private def readGlobal(json: Value): PhpGlobalStmt = - val vars = json("vars").arr.map(readExpr).toList + private def readConst(json: Value): PhpConstStmt = + val modifiers = PhpModifiers.getModifierSet(json) - PhpGlobalStmt(vars, PhpAttributes(json)) + val constDeclarations = json("consts").arr.map(readConstDeclaration).toList - private def readUse(json: Value): PhpUseStmt = - val useType = getUseType(json("type").num.toInt) - val uses = json("uses").arr.map(readUseUse(_, useType)).toList + PhpConstStmt(modifiers, constDeclarations, PhpAttributes(json)) - PhpUseStmt(uses, useType, PhpAttributes(json)) + private def readGoto(json: Value): PhpGotoStmt = + val name = readName(json("name")) + PhpGotoStmt(name, PhpAttributes(json)) - private def readGroupUse(json: Value): PhpGroupUseStmt = - val prefix = readName(json("prefix")) - val useType = getUseType(json("type").num.toInt) - val uses = json("uses").arr.map(readUseUse(_, useType)).toList + private def readLabel(json: Value): PhpLabelStmt = + val name = readName(json("name")) + PhpLabelStmt(name, PhpAttributes(json)) - PhpGroupUseStmt(prefix, uses, useType, PhpAttributes(json)) + private def readHaltCompiler(json: Value): PhpHaltCompilerStmt = + // Ignore the remaining text here since it can get quite large (common use case is to separate code from data blob) + PhpHaltCompilerStmt(PhpAttributes(json)) - private def readForeach(json: Value): PhpForeachStmt = - val iterExpr = readExpr(json("expr")) - val keyVar = Option.unless(json("keyVar").isNull)(readExpr(json("keyVar"))) - val valueVar = readExpr(json("valueVar")) - val assignByRef = json("byRef").bool - val stmts = json("stmts").arr.map(readStmt).toList + private def readNamespace(json: Value): PhpNamespaceStmt = + val name = Option.unless(json("name").isNull)(readName(json("name"))) - PhpForeachStmt(iterExpr, keyVar, valueVar, assignByRef, stmts, PhpAttributes(json)) + val stmts = json("stmts") match + case ujson.Null => Nil + case stmts: Arr => stmts.arr.map(readStmt).toList + case unhandled => + logger.debug(s"Unhandled namespace stmts type $unhandled") + ??? - private def readTraitUse(json: Value): PhpTraitUseStmt = - val traits = json("traits").arr.map(readName).toList - val adaptations = json("adaptations").arr.map(readTraitUseAdaptation).toList - PhpTraitUseStmt(traits, adaptations, PhpAttributes(json)) + PhpNamespaceStmt(name, stmts, PhpAttributes(json)) - private def readTraitUseAdaptation(json: Value): PhpTraitUseAdaptation = - json("nodeType").str match - case "Stmt_TraitUseAdaptation_Alias" => readAliasAdaptation(json) - case "Stmt_TraitUseAdaptation_Precedence" => readPrecedenceAdaptation(json) + private def readDeclare(json: Value): PhpDeclareStmt = + val declares = json("declares").arr.map(readDeclareItem).toList + val stmts = Option.unless(json("stmts").isNull)(json("stmts").arr.map(readStmt).toList) - private def readAliasAdaptation(json: Value): PhpAliasAdaptation = - val traitName = Option.unless(json("trait").isNull)(readName(json("trait"))) - val methodName = readName(json("method")) - val newName = Option.unless(json("newName").isNull)(readName(json("newName"))) + PhpDeclareStmt(declares, stmts, PhpAttributes(json)) - val newModifier = json("newModifier") match - case ujson.Null => None - case _ => PhpModifiers.getModifierSet(json, "newModifier").headOption - PhpAliasAdaptation(traitName, methodName, newModifier, newName, PhpAttributes(json)) + private def readUnset(json: Value): PhpUnsetStmt = + val vars = json("vars").arr.map(readExpr).toList - private def readPrecedenceAdaptation(json: Value): PhpPrecedenceAdaptation = - val traitName = readName(json("trait")) - val methodName = readName(json("method")) - val insteadOf = json("insteadof").arr.map(readName).toList + PhpUnsetStmt(vars, PhpAttributes(json)) - PhpPrecedenceAdaptation(traitName, methodName, insteadOf, PhpAttributes(json)) + private def readStatic(json: Value): PhpStaticStmt = + val vars = json("vars").arr.map(readStaticVar).toList - private def readUseUse(json: Value, parentType: PhpUseType): PhpUseUse = - val name = readName(json("name")) - val alias = Option.unless(json("alias").isNull)(readName(json("alias"))) - val useType = - if parentType == PhpUseType.Unknown then - getUseType(json("type").num.toInt) - else - parentType + PhpStaticStmt(vars, PhpAttributes(json)) - PhpUseUse(name, alias, useType, PhpAttributes(json)) - - private def readStaticVar(json: Value): PhpStaticVar = - val variable = readVariable(json("var")) - val defaultValue = Option.unless(json("default").isNull)(readExpr(json("default"))) - - PhpStaticVar(variable, defaultValue, PhpAttributes(json)) + private def readGlobal(json: Value): PhpGlobalStmt = + val vars = json("vars").arr.map(readExpr).toList - private def readDeclareItem(json: Value): PhpDeclareItem = - val key = readName(json("key")) - val value = readExpr(json("value")) + PhpGlobalStmt(vars, PhpAttributes(json)) - PhpDeclareItem(key, value, PhpAttributes(json)) - - private def readConstDeclaration(json: Value): PhpConstDeclaration = - val name = readName(json("name")) - val value = readExpr(json("value")) - val namespacedName = - Option.unless(json.obj.get("namespacedName").isNull)(readName( - json.obj.get("namespacedName") - )) - PhpConstDeclaration(name, value, namespacedName, PhpAttributes(json)) + private def readUse(json: Value): PhpUseStmt = + val useType = getUseType(json("type").num.toInt) + val uses = json("uses").arr.map(readUseUse(_, useType)).toList - private def readParam(json: Value): PhpParam = - val paramType = Option.unless(json("type").isNull)(readType(json("type"))) - PhpParam( - name = json("var")("name").str, - paramType = paramType, - byRef = json("byRef").bool, - isVariadic = json("variadic").bool, - default = json.obj.get("default").filterNot(_.isNull).map(readExpr), - flags = json("flags").num.toInt, - attributes = PhpAttributes(json) - ) + PhpUseStmt(uses, useType, PhpAttributes(json)) - private def readName(json: Value): PhpNameExpr = - json match - case Str(name) => PhpNameExpr(name, PhpAttributes.Empty) + private def readGroupUse(json: Value): PhpGroupUseStmt = + val prefix = readName(json("prefix")) + val useType = getUseType(json("type").num.toInt) + val uses = json("uses").arr.map(readUseUse(_, useType)).toList - case Obj(value) if value.get("nodeType").map(_.str).contains("Name_FullyQualified") => - val name = if value.get("parts").nonEmpty then - value("parts").arr.map(_.str).mkString(NamespaceDelimiter) - else Try(value("name").str).getOrElse("") - PhpNameExpr(name, PhpAttributes(json)) + PhpGroupUseStmt(prefix, uses, useType, PhpAttributes(json)) - case Obj(value) if value.get("nodeType").map(_.str).contains("Name") => - val name = if value.get("parts").nonEmpty then - value("parts").arr.map(_.str).mkString(NamespaceDelimiter) - else Try(value("name").str).getOrElse("") - PhpNameExpr(name, PhpAttributes(json)) + private def readForeach(json: Value): PhpForeachStmt = + val iterExpr = readExpr(json("expr")) + val keyVar = Option.unless(json("keyVar").isNull)(readExpr(json("keyVar"))) + val valueVar = readExpr(json("valueVar")) + val assignByRef = json("byRef").bool + val stmts = json("stmts").arr.map(readStmt).toList - case Obj(value) if value.get("nodeType").map(_.str).contains("Identifier") => - val name = Try(value("name").str).getOrElse("") - if name.nonEmpty then PhpNameExpr(name, PhpAttributes(json)) - else PhpNameExpr("anonymous", PhpAttributes.Empty) + PhpForeachStmt(iterExpr, keyVar, valueVar, assignByRef, stmts, PhpAttributes(json)) - case Obj(value) if value.get("nodeType").map(_.str).contains("VarLikeIdentifier") => - val name = Try(value("name").str).getOrElse("") - if name.nonEmpty then PhpNameExpr(name, PhpAttributes(json)) - else PhpNameExpr("anonymous", PhpAttributes.Empty) + private def readTraitUse(json: Value): PhpTraitUseStmt = + val traits = json("traits").arr.map(readName).toList + val adaptations = json("adaptations").arr.map(readTraitUseAdaptation).toList + PhpTraitUseStmt(traits, adaptations, PhpAttributes(json)) - case arr: Arr => PhpNameExpr(json.toString, PhpAttributes.Empty) + private def readTraitUseAdaptation(json: Value): PhpTraitUseAdaptation = + json("nodeType").str match + case "Stmt_TraitUseAdaptation_Alias" => readAliasAdaptation(json) + case "Stmt_TraitUseAdaptation_Precedence" => readPrecedenceAdaptation(json) - case unhandled => - logger.debug(s"Found unhandled name type $unhandled: $json") - ??? // TODO: other matches are possible? + private def readAliasAdaptation(json: Value): PhpAliasAdaptation = + val traitName = Option.unless(json("trait").isNull)(readName(json("trait"))) + val methodName = readName(json("method")) + val newName = Option.unless(json("newName").isNull)(readName(json("newName"))) - /** One of Identifier, Name, or Complex Type (Nullable, Intersection, or Union) - */ - private def readType(json: Value): PhpNameExpr = - json match - case Obj(value) if value.get("nodeType").map(_.str).contains("NullableType") => - val containedName = readType(value("type")).name - PhpNameExpr(s"?$containedName", attributes = PhpAttributes(json)) + val newModifier = json("newModifier") match + case ujson.Null => None + case _ => PhpModifiers.getModifierSet(json, "newModifier").headOption + PhpAliasAdaptation(traitName, methodName, newModifier, newName, PhpAttributes(json)) - case Obj(value) if value.get("nodeType").map(_.str).contains("IntersectionType") => - val names = value("types").arr.map(readName).map(_.name) - PhpNameExpr(names.mkString("&"), PhpAttributes(json)) + private def readPrecedenceAdaptation(json: Value): PhpPrecedenceAdaptation = + val traitName = readName(json("trait")) + val methodName = readName(json("method")) + val insteadOf = json("insteadof").arr.map(readName).toList - case Obj(value) if value.get("nodeType").map(_.str).contains("UnionType") => - val names = value("types").arr.map(readType).map(_.name) - PhpNameExpr(names.mkString("|"), PhpAttributes(json)) + PhpPrecedenceAdaptation(traitName, methodName, insteadOf, PhpAttributes(json)) - case other => readName(other) + private def readUseUse(json: Value, parentType: PhpUseType): PhpUseUse = + val name = readName(json("name")) + val alias = Option.unless(json("alias").isNull)(readName(json("alias"))) + val useType = + if parentType == PhpUseType.Unknown then + getUseType(json("type").num.toInt) + else + parentType + + PhpUseUse(name, alias, useType, PhpAttributes(json)) + + private def readStaticVar(json: Value): PhpStaticVar = + val variable = readVariable(json("var")) + val defaultValue = Option.unless(json("default").isNull)(readExpr(json("default"))) + + PhpStaticVar(variable, defaultValue, PhpAttributes(json)) + + private def readDeclareItem(json: Value): PhpDeclareItem = + val key = readName(json("key")) + val value = readExpr(json("value")) + + PhpDeclareItem(key, value, PhpAttributes(json)) + + private def readConstDeclaration(json: Value): PhpConstDeclaration = + val name = readName(json("name")) + val value = readExpr(json("value")) + val namespacedName = + Option.unless(json.obj.get("namespacedName").isNull)(readName( + json.obj.get("namespacedName") + )) + PhpConstDeclaration(name, value, namespacedName, PhpAttributes(json)) + + private def readParam(json: Value): PhpParam = + val paramType = Option.unless(json("type").isNull)(readType(json("type"))) + PhpParam( + name = json("var")("name").str, + paramType = paramType, + byRef = json("byRef").bool, + isVariadic = json("variadic").bool, + default = json.obj.get("default").filterNot(_.isNull).map(readExpr), + flags = json("flags").num.toInt, + attributes = PhpAttributes(json) + ) - private def readUnaryOp(json: Value): PhpUnaryOp = - val opType = UnaryOpTypeMap(json("nodeType").str) + private def readName(json: Value): PhpNameExpr = + json match + case Str(name) => PhpNameExpr(name, PhpAttributes.Empty) + + case Obj(value) if value.get("nodeType").map(_.str).contains("Name_FullyQualified") => + val name = if value.get("parts").nonEmpty then + value("parts").arr.map(_.str).mkString(NamespaceDelimiter) + else Try(value("name").str).getOrElse("") + PhpNameExpr(name, PhpAttributes(json)) + + case Obj(value) if value.get("nodeType").map(_.str).contains("Name") => + val name = if value.get("parts").nonEmpty then + value("parts").arr.map(_.str).mkString(NamespaceDelimiter) + else Try(value("name").str).getOrElse("") + PhpNameExpr(name, PhpAttributes(json)) + + case Obj(value) if value.get("nodeType").map(_.str).contains("Identifier") => + val name = Try(value("name").str).getOrElse("") + if name.nonEmpty then PhpNameExpr(name, PhpAttributes(json)) + else PhpNameExpr("anonymous", PhpAttributes.Empty) + + case Obj(value) if value.get("nodeType").map(_.str).contains("VarLikeIdentifier") => + val name = Try(value("name").str).getOrElse("") + if name.nonEmpty then PhpNameExpr(name, PhpAttributes(json)) + else PhpNameExpr("anonymous", PhpAttributes.Empty) + + case arr: Arr => PhpNameExpr(json.toString, PhpAttributes.Empty) + + case unhandled => + logger.debug(s"Found unhandled name type $unhandled: $json") + ??? // TODO: other matches are possible? + + /** One of Identifier, Name, or Complex Type (Nullable, Intersection, or Union) + */ + private def readType(json: Value): PhpNameExpr = + json match + case Obj(value) if value.get("nodeType").map(_.str).contains("NullableType") => + val containedName = readType(value("type")).name + PhpNameExpr(s"?$containedName", attributes = PhpAttributes(json)) + + case Obj(value) if value.get("nodeType").map(_.str).contains("IntersectionType") => + val names = value("types").arr.map(readName).map(_.name) + PhpNameExpr(names.mkString("&"), PhpAttributes(json)) + + case Obj(value) if value.get("nodeType").map(_.str).contains("UnionType") => + val names = value("types").arr.map(readType).map(_.name) + PhpNameExpr(names.mkString("|"), PhpAttributes(json)) + + case other => readName(other) + + private def readUnaryOp(json: Value): PhpUnaryOp = + val opType = UnaryOpTypeMap(json("nodeType").str) + + val expr = + if json.obj.contains("expr") then + readExpr(json.obj("expr")) + else if json.obj.contains("var") then + readExpr(json.obj("var")) + else + throw new UnsupportedOperationException( + s"Expected expr or var field in unary op but found $json" + ) - val expr = - if json.obj.contains("expr") then - readExpr(json.obj("expr")) - else if json.obj.contains("var") then - readExpr(json.obj("var")) - else - throw new UnsupportedOperationException( - s"Expected expr or var field in unary op but found $json" - ) + PhpUnaryOp(opType, expr, PhpAttributes(json)) - PhpUnaryOp(opType, expr, PhpAttributes(json)) + private def readBinaryOp(json: Value): PhpBinaryOp = + val opType = BinaryOpTypeMap(json("nodeType").str) - private def readBinaryOp(json: Value): PhpBinaryOp = - val opType = BinaryOpTypeMap(json("nodeType").str) + val leftExpr = readExpr(json("left")) + val rightExpr = readExpr(json("right")) - val leftExpr = readExpr(json("left")) - val rightExpr = readExpr(json("right")) + PhpBinaryOp(opType, leftExpr, rightExpr, PhpAttributes(json)) - PhpBinaryOp(opType, leftExpr, rightExpr, PhpAttributes(json)) + private def readAssign(json: Value): PhpAssignment = + val nodeType = json("nodeType").str + val opType = AssignTypeMap(nodeType) - private def readAssign(json: Value): PhpAssignment = - val nodeType = json("nodeType").str - val opType = AssignTypeMap(nodeType) + val target = readExpr(json("var")) + val source = readExpr(json("expr")) - val target = readExpr(json("var")) - val source = readExpr(json("expr")) + val isRefAssign = nodeType == "Expr_AssignRef" - val isRefAssign = nodeType == "Expr_AssignRef" + PhpAssignment(opType, target, source, isRefAssign, PhpAttributes(json)) - PhpAssignment(opType, target, source, isRefAssign, PhpAttributes(json)) + private def readCast(json: Value): PhpCast = + val typ = CastTypeMap(json("nodeType").str) + val expr = readExpr(json("expr")) - private def readCast(json: Value): PhpCast = - val typ = CastTypeMap(json("nodeType").str) - val expr = readExpr(json("expr")) + PhpCast(typ, expr, PhpAttributes(json)) - PhpCast(typ, expr, PhpAttributes(json)) + private def readCallArg(json: Value): PhpArgument = + json("nodeType").str match + case "Arg" => + PhpArg( + expr = readExpr(json("value")), + parameterName = json.obj.get("name").filterNot(_.isNull).map(_("name").str), + byRef = json("byRef").bool, + unpack = json("unpack").bool, + attributes = PhpAttributes(json) + ) - private def readCallArg(json: Value): PhpArgument = - json("nodeType").str match - case "Arg" => - PhpArg( - expr = readExpr(json("value")), - parameterName = json.obj.get("name").filterNot(_.isNull).map(_("name").str), - byRef = json("byRef").bool, - unpack = json("unpack").bool, - attributes = PhpAttributes(json) - ) - - case "VariadicPlaceholder" => PhpVariadicPlaceholder(PhpAttributes(json)) + case "VariadicPlaceholder" => PhpVariadicPlaceholder(PhpAttributes(json)) - def fromJson(jsonInput: Value): PhpFile = - readFile(jsonInput) + def fromJson(jsonInput: Value): PhpFile = + readFile(jsonInput) end Domain diff --git a/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/parser/PhpParser.scala b/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/parser/PhpParser.scala index 9c5d19c1..adad75a3 100644 --- a/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/parser/PhpParser.scala +++ b/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/parser/PhpParser.scala @@ -11,131 +11,131 @@ import scala.util.{Failure, Success, Try} class PhpParser private (phpParserPath: String, phpIniPath: String): - private val logger = LoggerFactory.getLogger(this.getClass) - - private def phpParseCommand(filename: String): String = - val phpParserCommands = "--with-recovery --resolve-names -P --json-dump" - phpParserPath match - case "phpastgen" => - s"$phpParserPath $phpParserCommands $filename" - case _ => - s"php --php-ini $phpIniPath $phpParserPath $phpParserCommands $filename" - - def parseFile(inputPath: String, phpIniOverride: Option[String]): Option[PhpFile] = - val inputFile = File(inputPath) - val inputFilePath = inputFile.canonicalPath - val inputDirectory = inputFile.parent.canonicalPath - - val command = phpParseCommand(inputFilePath) - - ExternalCommand.run(command, inputDirectory, true) match - case Success(output) => - processParserOutput(output, inputFilePath) - - case Failure(exception) => - logger.debug(s"Failure running php-parser with $command", exception.getMessage) - None - - private def processParserOutput(output: Seq[String], filename: String): Option[PhpFile] = - val maybeJson = - linesToJsonValue(output, filename) - maybeJson.flatMap(jsonValueToPhpFile(_, filename)) - - private def linesToJsonValue(lines: Seq[String], filename: String): Option[ujson.Value] = - if lines.exists(_.startsWith("[")) then - val jsonString = lines.dropWhile(_.charAt(0) != '[').mkString("\n") - Try(Option(ujson.read(jsonString))) match - case Success(Some(value)) => Some(value) - - case Success(None) => - logger.debug(s"Parsing json string for $filename resulted in null return value") - None - - case Failure(exception) => - logger.debug( - s"Parsing json string for $filename failed with exception", - exception - ) - None - else - logger.debug(s"No JSON output for $filename") + private val logger = LoggerFactory.getLogger(this.getClass) + + private def phpParseCommand(filename: String): String = + val phpParserCommands = "--with-recovery --resolve-names -P --json-dump" + phpParserPath match + case "phpastgen" => + s"$phpParserPath $phpParserCommands $filename" + case _ => + s"php --php-ini $phpIniPath $phpParserPath $phpParserCommands $filename" + + def parseFile(inputPath: String, phpIniOverride: Option[String]): Option[PhpFile] = + val inputFile = File(inputPath) + val inputFilePath = inputFile.canonicalPath + val inputDirectory = inputFile.parent.canonicalPath + + val command = phpParseCommand(inputFilePath) + + ExternalCommand.run(command, inputDirectory, true) match + case Success(output) => + processParserOutput(output, inputFilePath) + + case Failure(exception) => + logger.debug(s"Failure running php-parser with $command", exception.getMessage) + None + + private def processParserOutput(output: Seq[String], filename: String): Option[PhpFile] = + val maybeJson = + linesToJsonValue(output, filename) + maybeJson.flatMap(jsonValueToPhpFile(_, filename)) + + private def linesToJsonValue(lines: Seq[String], filename: String): Option[ujson.Value] = + if lines.exists(_.startsWith("[")) then + val jsonString = lines.dropWhile(_.charAt(0) != '[').mkString("\n") + Try(Option(ujson.read(jsonString))) match + case Success(Some(value)) => Some(value) + + case Success(None) => + logger.debug(s"Parsing json string for $filename resulted in null return value") + None + + case Failure(exception) => + logger.debug( + s"Parsing json string for $filename failed with exception", + exception + ) + None + else + logger.debug(s"No JSON output for $filename") + None + + private def jsonValueToPhpFile(json: ujson.Value, filename: String): Option[PhpFile] = + Try(Domain.fromJson(json)) match + case Success(phpFile) => Some(phpFile) + + case Failure(e) => + logger.debug(s"Failed to generate intermediate AST for $filename", e) None - - private def jsonValueToPhpFile(json: ujson.Value, filename: String): Option[PhpFile] = - Try(Domain.fromJson(json)) match - case Success(phpFile) => Some(phpFile) - - case Failure(e) => - logger.debug(s"Failed to generate intermediate AST for $filename", e) - None end PhpParser object PhpParser: - private val logger = LoggerFactory.getLogger(this.getClass()) - - val PhpParserBinEnvVar = "PHP_PARSER_BIN" - - private def defaultPhpIni: String = - val tmpIni = File.newTemporaryFile(suffix = "-php.ini").deleteOnExit() - tmpIni.writeText("memory_limit = -1") - tmpIni.canonicalPath - - private def isPhpAstgenSupported: Boolean = - val result = ExternalCommand.run("phpastgen --help", ".") - result match - case Success(listString) => - true - case Failure(exception) => - false - - private def defaultPhpParserBin: String = - val dir = - Paths.get( - this.getClass.getProtectionDomain.getCodeSource.getLocation.toURI - ).toAbsolutePath.toString - - val fixedDir = new java.io.File(dir.substring(0, dir.indexOf("php2atom"))).toString - - val builtInGen = Paths.get( - fixedDir, - "php2atom", - "vendor", - "bin", - "php-parse" + private val logger = LoggerFactory.getLogger(this.getClass()) + + val PhpParserBinEnvVar = "PHP_PARSER_BIN" + + private def defaultPhpIni: String = + val tmpIni = File.newTemporaryFile(suffix = "-php.ini").deleteOnExit() + tmpIni.writeText("memory_limit = -1") + tmpIni.canonicalPath + + private def isPhpAstgenSupported: Boolean = + val result = ExternalCommand.run("phpastgen --help", ".") + result match + case Success(listString) => + true + case Failure(exception) => + false + + private def defaultPhpParserBin: String = + val dir = + Paths.get( + this.getClass.getProtectionDomain.getCodeSource.getLocation.toURI ).toAbsolutePath.toString - if File(builtInGen).exists() then builtInGen - else "phpastgen" - - private def configOverrideOrDefaultPath( - identifier: String, - maybeOverride: Option[String], - defaultValue: => String - ): Option[String] = - val pathString = maybeOverride match - case Some(overridePath) if overridePath.nonEmpty => - overridePath - case _ => - defaultValue - - File(pathString) match - case file if file.exists() && file.isRegularFile(File.LinkOptions.follow) => - Some(file.canonicalPath) - case _ => Some(defaultValue) - - private def maybePhpParserPath(config: Config): Option[String] = - val phpParserPathOverride = - config.phpParserBin - .orElse(Option(System.getenv(PhpParserBinEnvVar))) - - configOverrideOrDefaultPath("PhpParserBin", phpParserPathOverride, defaultPhpParserBin) - - private def maybePhpIniPath(config: Config): Option[String] = - configOverrideOrDefaultPath("PhpIni", config.phpIni, defaultPhpIni) - - def getParser(config: Config): Option[PhpParser] = - for ( - phpParserPath <- maybePhpParserPath(config); - phpIniPath <- maybePhpIniPath(config) - ) - yield new PhpParser(phpParserPath, phpIniPath) + + val fixedDir = new java.io.File(dir.substring(0, dir.indexOf("php2atom"))).toString + + val builtInGen = Paths.get( + fixedDir, + "php2atom", + "vendor", + "bin", + "php-parse" + ).toAbsolutePath.toString + if File(builtInGen).exists() then builtInGen + else "phpastgen" + + private def configOverrideOrDefaultPath( + identifier: String, + maybeOverride: Option[String], + defaultValue: => String + ): Option[String] = + val pathString = maybeOverride match + case Some(overridePath) if overridePath.nonEmpty => + overridePath + case _ => + defaultValue + + File(pathString) match + case file if file.exists() && file.isRegularFile(File.LinkOptions.follow) => + Some(file.canonicalPath) + case _ => Some(defaultValue) + + private def maybePhpParserPath(config: Config): Option[String] = + val phpParserPathOverride = + config.phpParserBin + .orElse(Option(System.getenv(PhpParserBinEnvVar))) + + configOverrideOrDefaultPath("PhpParserBin", phpParserPathOverride, defaultPhpParserBin) + + private def maybePhpIniPath(config: Config): Option[String] = + configOverrideOrDefaultPath("PhpIni", config.phpIni, defaultPhpIni) + + def getParser(config: Config): Option[PhpParser] = + for ( + phpParserPath <- maybePhpParserPath(config); + phpIniPath <- maybePhpIniPath(config) + ) + yield new PhpParser(phpParserPath, phpIniPath) end PhpParser diff --git a/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/passes/AnyTypePass.scala b/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/passes/AnyTypePass.scala index 48138e7b..2463d907 100644 --- a/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/passes/AnyTypePass.scala +++ b/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/passes/AnyTypePass.scala @@ -12,10 +12,10 @@ import io.shiftleft.semanticcpg.language.* // or do it elsewhere. class AnyTypePass(cpg: Cpg) extends ConcurrentWriterCpgPass[AstNode](cpg): - override def generateParts(): Array[AstNode] = - cpg.has(PropertyNames.TYPE_FULL_NAME, PropertyDefaults.TypeFullName).collectAll[ - AstNode - ].toArray + override def generateParts(): Array[AstNode] = + cpg.has(PropertyNames.TYPE_FULL_NAME, PropertyDefaults.TypeFullName).collectAll[ + AstNode + ].toArray - override def runOnPart(diffGraph: DiffGraphBuilder, node: AstNode): Unit = - diffGraph.setNodeProperty(node, PropertyNames.TYPE_FULL_NAME, AstCreator.TypeConstants.Any) + override def runOnPart(diffGraph: DiffGraphBuilder, node: AstNode): Unit = + diffGraph.setNodeProperty(node, PropertyNames.TYPE_FULL_NAME, AstCreator.TypeConstants.Any) diff --git a/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/passes/AstCreationPass.scala b/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/passes/AstCreationPass.scala index 955cba01..9afeed9c 100644 --- a/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/passes/AstCreationPass.scala +++ b/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/passes/AstCreationPass.scala @@ -16,35 +16,35 @@ class AstCreationPass(config: Config, cpg: Cpg, parser: PhpParser)(implicit withSchemaValidation: ValidationMode ) extends ConcurrentWriterCpgPass[String](cpg): - private val logger = LoggerFactory.getLogger(this.getClass) - - val PhpSourceFileExtensions: Set[String] = Set(".php") - - override def generateParts(): Array[String] = SourceFiles - .determine( - config.inputPath, - PhpSourceFileExtensions - ) - .toArray - - override def runOnPart(diffGraph: DiffGraphBuilder, filename: String): Unit = - val relativeFilename = if filename == config.inputPath then - File(filename).name - else - File(config.inputPath).relativize(File(filename)).toString - if !config.ignoredFilesRegex.matches( - relativeFilename - ) && !config.defaultIgnoredFilesRegex.exists(_.matches(relativeFilename)) - then - parser.parseFile(filename, config.phpIni) match - case Some(parseResult) => - diffGraph.absorb( - new AstCreator(relativeFilename, parseResult)( - config.schemaValidation - ).createAst() - ) - - case None => - logger.debug(s"Could not parse file $filename. Results will be missing!") - end runOnPart + private val logger = LoggerFactory.getLogger(this.getClass) + + val PhpSourceFileExtensions: Set[String] = Set(".php") + + override def generateParts(): Array[String] = SourceFiles + .determine( + config.inputPath, + PhpSourceFileExtensions + ) + .toArray + + override def runOnPart(diffGraph: DiffGraphBuilder, filename: String): Unit = + val relativeFilename = if filename == config.inputPath then + File(filename).name + else + File(config.inputPath).relativize(File(filename)).toString + if !config.ignoredFilesRegex.matches( + relativeFilename + ) && !config.defaultIgnoredFilesRegex.exists(_.matches(relativeFilename)) + then + parser.parseFile(filename, config.phpIni) match + case Some(parseResult) => + diffGraph.absorb( + new AstCreator(relativeFilename, parseResult)( + config.schemaValidation + ).createAst() + ) + + case None => + logger.debug(s"Could not parse file $filename. Results will be missing!") + end runOnPart end AstCreationPass diff --git a/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/passes/AstParentInfoPass.scala b/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/passes/AstParentInfoPass.scala index ff6206d8..cb16fe45 100644 --- a/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/passes/AstParentInfoPass.scala +++ b/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/passes/AstParentInfoPass.scala @@ -8,29 +8,29 @@ import io.shiftleft.semanticcpg.language.* class AstParentInfoPass(cpg: Cpg) extends ConcurrentWriterCpgPass[AstNode](cpg): - override def generateParts(): Array[AstNode] = - (cpg.method ++ cpg.typeDecl).toArray + override def generateParts(): Array[AstNode] = + (cpg.method ++ cpg.typeDecl).toArray - override def runOnPart(diffGraph: DiffGraphBuilder, node: AstNode): Unit = - findParent(node).foreach { parentNode => - val astParentType = parentNode.label - val astParentFullName = parentNode.property(PropertyNames.FULL_NAME) + override def runOnPart(diffGraph: DiffGraphBuilder, node: AstNode): Unit = + findParent(node).foreach { parentNode => + val astParentType = parentNode.label + val astParentFullName = parentNode.property(PropertyNames.FULL_NAME) - diffGraph.setNodeProperty(node, PropertyNames.AST_PARENT_TYPE, astParentType) - diffGraph.setNodeProperty(node, PropertyNames.AST_PARENT_FULL_NAME, astParentFullName) - } + diffGraph.setNodeProperty(node, PropertyNames.AST_PARENT_TYPE, astParentType) + diffGraph.setNodeProperty(node, PropertyNames.AST_PARENT_FULL_NAME, astParentFullName) + } - private def hasValidContainingNodes(nodes: Iterator[AstNode]): Iterator[AstNode] = - nodes.collect { - case m: Method => m - case t: TypeDecl => t - case n: NamespaceBlock => n - } + private def hasValidContainingNodes(nodes: Iterator[AstNode]): Iterator[AstNode] = + nodes.collect { + case m: Method => m + case t: TypeDecl => t + case n: NamespaceBlock => n + } - def findParent(node: AstNode): Option[AstNode] = - node.start - .repeat(_.astParent)( - _.until(hasValidContainingNodes(_)).emit(hasValidContainingNodes(_)) - ) - .find(_ != node) + def findParent(node: AstNode): Option[AstNode] = + node.start + .repeat(_.astParent)( + _.until(hasValidContainingNodes(_)).emit(hasValidContainingNodes(_)) + ) + .find(_ != node) end AstParentInfoPass diff --git a/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/passes/ClosureRefPass.scala b/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/passes/ClosureRefPass.scala index 17d5e270..99a9dcf4 100644 --- a/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/passes/ClosureRefPass.scala +++ b/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/passes/ClosureRefPass.scala @@ -10,64 +10,64 @@ import io.shiftleft.codepropertygraph.generated.nodes.AstNode import io.shiftleft.codepropertygraph.generated.nodes.Local class ClosureRefPass(cpg: Cpg) extends ConcurrentWriterCpgPass[ClosureBinding](cpg): - private val logger = LoggerFactory.getLogger(this.getClass) + private val logger = LoggerFactory.getLogger(this.getClass) - override def generateParts(): Array[ClosureBinding] = cpg.all.collectAll[ClosureBinding].toArray + override def generateParts(): Array[ClosureBinding] = cpg.all.collectAll[ClosureBinding].toArray - /** The AstCreator adds closureBindingIds and ClosureBindings for captured locals, but does not - * add the required REF edges from the ClosureBinding to the captured node since the captured - * node may be a Local that is created by the LocalCreationPass and does not exist during AST - * creation. - * - * This pass attempts to find the captured node in the method containing the MethodRef to the - * closure method, since that is the scope in which the closure would have originally been - * created. - */ - override def runOnPart(diffGraph: DiffGraphBuilder, closureBinding: ClosureBinding): Unit = - closureBinding.captureIn.collectAll[MethodRef].toList match - case Nil => - logger.debug( - s"No MethodRef corresponding to closureBinding ${closureBinding.closureBindingId}" - ) + /** The AstCreator adds closureBindingIds and ClosureBindings for captured locals, but does not + * add the required REF edges from the ClosureBinding to the captured node since the captured + * node may be a Local that is created by the LocalCreationPass and does not exist during AST + * creation. + * + * This pass attempts to find the captured node in the method containing the MethodRef to the + * closure method, since that is the scope in which the closure would have originally been + * created. + */ + override def runOnPart(diffGraph: DiffGraphBuilder, closureBinding: ClosureBinding): Unit = + closureBinding.captureIn.collectAll[MethodRef].toList match + case Nil => + logger.debug( + s"No MethodRef corresponding to closureBinding ${closureBinding.closureBindingId}" + ) - case methodRef :: Nil => - addRefToCapturedNode(diffGraph, closureBinding, getMethod(methodRef)) + case methodRef :: Nil => + addRefToCapturedNode(diffGraph, closureBinding, getMethod(methodRef)) - case methodRefs => - logger.debug( - s"Mutliple MethodRefs corresponding to closureBinding ${closureBinding.closureBindingId}" - ) - logger.debug(s"${closureBinding.closureBindingId} MethodRefs = ${methodRefs}") + case methodRefs => + logger.debug( + s"Mutliple MethodRefs corresponding to closureBinding ${closureBinding.closureBindingId}" + ) + logger.debug(s"${closureBinding.closureBindingId} MethodRefs = ${methodRefs}") - private def getMethod(methodRef: MethodRef): Option[Method] = - methodRef.start.repeat(_.astParent)( - _.until(_.isMethod).emit(_.isMethod) - ).isMethod.headOption + private def getMethod(methodRef: MethodRef): Option[Method] = + methodRef.start.repeat(_.astParent)( + _.until(_.isMethod).emit(_.isMethod) + ).isMethod.headOption - private def addRefToCapturedNode( - diffGraph: DiffGraphBuilder, - closureBinding: ClosureBinding, - method: Option[Method] - ): Unit = - method match - case None => - logger.debug( - s"No parent method for methodRef for ${closureBinding.closureBindingId}. REF edge will be missing" - ) + private def addRefToCapturedNode( + diffGraph: DiffGraphBuilder, + closureBinding: ClosureBinding, + method: Option[Method] + ): Unit = + method match + case None => + logger.debug( + s"No parent method for methodRef for ${closureBinding.closureBindingId}. REF edge will be missing" + ) - case Some(method) => - closureBinding.closureOriginalName.foreach { name => - lazy val locals = - method.start.repeat(_.astChildren.filterNot(_.isMethod))( - _.emit(_.isLocal) - ).collectAll[Local] - val maybeCaptured = - method.parameter - .find(_.name == name) - .orElse(locals.find(_.name == name)) + case Some(method) => + closureBinding.closureOriginalName.foreach { name => + lazy val locals = + method.start.repeat(_.astChildren.filterNot(_.isMethod))( + _.emit(_.isLocal) + ).collectAll[Local] + val maybeCaptured = + method.parameter + .find(_.name == name) + .orElse(locals.find(_.name == name)) - maybeCaptured.foreach { captured => - diffGraph.addEdge(closureBinding, captured, EdgeTypes.REF) - } - } + maybeCaptured.foreach { captured => + diffGraph.addEdge(closureBinding, captured, EdgeTypes.REF) + } + } end ClosureRefPass diff --git a/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/passes/ConfigFileCreationPass.scala b/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/passes/ConfigFileCreationPass.scala index 2986188e..0099073c 100644 --- a/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/passes/ConfigFileCreationPass.scala +++ b/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/passes/ConfigFileCreationPass.scala @@ -6,16 +6,16 @@ import io.shiftleft.codepropertygraph.Cpg class ConfigFileCreationPass(cpg: Cpg) extends XConfigFileCreationPass(cpg): - override val configFileFilters: List[File => Boolean] = List( - // TOML files - extensionFilter(".toml"), - // INI files - extensionFilter(".ini"), - // YAML files - extensionFilter(".yaml"), - extensionFilter(".lock"), - pathEndFilter("composer.json"), - pathEndFilter("bom.json"), - pathEndFilter(".cdx.json"), - pathEndFilter("chennai.json") - ) + override val configFileFilters: List[File => Boolean] = List( + // TOML files + extensionFilter(".toml"), + // INI files + extensionFilter(".ini"), + // YAML files + extensionFilter(".yaml"), + extensionFilter(".lock"), + pathEndFilter("composer.json"), + pathEndFilter("bom.json"), + pathEndFilter(".cdx.json"), + pathEndFilter("chennai.json") + ) diff --git a/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/passes/LocalCreationPass.scala b/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/passes/LocalCreationPass.scala index 11300b93..07c1462b 100644 --- a/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/passes/LocalCreationPass.scala +++ b/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/passes/LocalCreationPass.scala @@ -21,115 +21,115 @@ import io.appthreat.x2cpg.AstNodeBuilder import io.shiftleft.codepropertygraph.generated.PropertyNames object LocalCreationPass: - def allLocalCreationPasses(cpg: Cpg): Iterator[LocalCreationPass[? <: AstNode]] = - Iterator(new NamespaceLocalPass(cpg), new MethodLocalPass(cpg)) + def allLocalCreationPasses(cpg: Cpg): Iterator[LocalCreationPass[? <: AstNode]] = + Iterator(new NamespaceLocalPass(cpg), new MethodLocalPass(cpg)) abstract class LocalCreationPass[ScopeType <: AstNode](cpg: Cpg) extends ConcurrentWriterCpgPass[ScopeType](cpg) with AstNodeBuilder[AstNode, LocalCreationPass[ScopeType]]: - override protected def line(node: AstNode) = node.lineNumber - override protected def column(node: AstNode) = node.columnNumber - override protected def lineEnd(node: AstNode): Option[Integer] = None - override protected def columnEnd(node: AstNode): Option[Integer] = None - override protected def code(node: AstNode): String = node.code - - protected def getIdentifiersInScope(node: AstNode): List[Identifier] = - node match - case identifier: Identifier => identifier :: Nil - case _: TypeDecl | _: Method | _: NamespaceBlock => Nil - case _ if node.astChildren.isEmpty => Nil - case call: Call if call.name == PhpOperators.declareFunc => - // TODO Handle declares properly - // but for now don't change behaviour. - Nil - case _ => node.astChildren.flatMap(getIdentifiersInScope).toList - - protected def localsForIdentifiers( - identifierMap: Map[String, List[Identifier]] - ): List[(NewLocal, List[Identifier])] = - identifierMap - .map { case identifierName -> identifiers => - val code = s"$$$identifierName" - val local = - localNode( - identifiers.head, - identifierName, - code, - AstCreator.TypeConstants.Any, - closureBindingId = None - ) - (local -> identifiers) - } - .toList - .sortBy { case (local, _) => local.name } - - protected def addRefEdges( - diffGraph: DiffGraphBuilder, - localPairs: List[(NewLocal, List[Identifier])] - ): Unit = - localPairs.foreach { case (local, identifiers) => - identifiers.foreach { identifier => - diffGraph.addEdge(identifier, local, EdgeTypes.REF) - } - } - - protected def prependLocalsToBody( - diffGraph: DiffGraphBuilder, - bodyNode: AstNode, - locals: List[NewLocal] - ): Unit = - val originalChildren = bodyNode.astChildren.l - - bodyNode.outE(EdgeTypes.AST).foreach(diffGraph.removeEdge) - - locals.zipWithIndex.foreach { case (local, idx) => - local.order(idx + 1) - } - - val localCount = locals.size - - originalChildren.foreach { node => - diffGraph.setNodeProperty(node, PropertyNames.ORDER, node.order + localCount) - } - - (locals ++ originalChildren).foreach { node => - diffGraph.addEdge(bodyNode, node, EdgeTypes.AST) - } - end prependLocalsToBody - - protected def addLocalsToAst( - diffGraph: DiffGraphBuilder, - bodyNode: AstNode, - excludeIdentifierFn: Identifier => Boolean - ): Unit = - val identifierMap = - getIdentifiersInScope(bodyNode) - .filter(_.refOut.isEmpty) - .filterNot(excludeIdentifierFn) - .groupBy(_.name) - - val localPairs = localsForIdentifiers(identifierMap) - - if localPairs.nonEmpty then - val locals = localPairs.map { case (local, _) => local } - - addRefEdges(diffGraph, localPairs) - prependLocalsToBody(diffGraph, bodyNode, locals) + override protected def line(node: AstNode) = node.lineNumber + override protected def column(node: AstNode) = node.columnNumber + override protected def lineEnd(node: AstNode): Option[Integer] = None + override protected def columnEnd(node: AstNode): Option[Integer] = None + override protected def code(node: AstNode): String = node.code + + protected def getIdentifiersInScope(node: AstNode): List[Identifier] = + node match + case identifier: Identifier => identifier :: Nil + case _: TypeDecl | _: Method | _: NamespaceBlock => Nil + case _ if node.astChildren.isEmpty => Nil + case call: Call if call.name == PhpOperators.declareFunc => + // TODO Handle declares properly + // but for now don't change behaviour. + Nil + case _ => node.astChildren.flatMap(getIdentifiersInScope).toList + + protected def localsForIdentifiers( + identifierMap: Map[String, List[Identifier]] + ): List[(NewLocal, List[Identifier])] = + identifierMap + .map { case identifierName -> identifiers => + val code = s"$$$identifierName" + val local = + localNode( + identifiers.head, + identifierName, + code, + AstCreator.TypeConstants.Any, + closureBindingId = None + ) + (local -> identifiers) + } + .toList + .sortBy { case (local, _) => local.name } + + protected def addRefEdges( + diffGraph: DiffGraphBuilder, + localPairs: List[(NewLocal, List[Identifier])] + ): Unit = + localPairs.foreach { case (local, identifiers) => + identifiers.foreach { identifier => + diffGraph.addEdge(identifier, local, EdgeTypes.REF) + } + } + + protected def prependLocalsToBody( + diffGraph: DiffGraphBuilder, + bodyNode: AstNode, + locals: List[NewLocal] + ): Unit = + val originalChildren = bodyNode.astChildren.l + + bodyNode.outE(EdgeTypes.AST).foreach(diffGraph.removeEdge) + + locals.zipWithIndex.foreach { case (local, idx) => + local.order(idx + 1) + } + + val localCount = locals.size + + originalChildren.foreach { node => + diffGraph.setNodeProperty(node, PropertyNames.ORDER, node.order + localCount) + } + + (locals ++ originalChildren).foreach { node => + diffGraph.addEdge(bodyNode, node, EdgeTypes.AST) + } + end prependLocalsToBody + + protected def addLocalsToAst( + diffGraph: DiffGraphBuilder, + bodyNode: AstNode, + excludeIdentifierFn: Identifier => Boolean + ): Unit = + val identifierMap = + getIdentifiersInScope(bodyNode) + .filter(_.refOut.isEmpty) + .filterNot(excludeIdentifierFn) + .groupBy(_.name) + + val localPairs = localsForIdentifiers(identifierMap) + + if localPairs.nonEmpty then + val locals = localPairs.map { case (local, _) => local } + + addRefEdges(diffGraph, localPairs) + prependLocalsToBody(diffGraph, bodyNode, locals) end LocalCreationPass class NamespaceLocalPass(cpg: Cpg) extends LocalCreationPass[NamespaceBlock](cpg): - override def generateParts(): Array[NamespaceBlock] = cpg.namespaceBlock.toArray + override def generateParts(): Array[NamespaceBlock] = cpg.namespaceBlock.toArray - override def runOnPart(diffGraph: DiffGraphBuilder, namespace: NamespaceBlock): Unit = - addLocalsToAst(diffGraph, namespace, excludeIdentifierFn = _ => false) + override def runOnPart(diffGraph: DiffGraphBuilder, namespace: NamespaceBlock): Unit = + addLocalsToAst(diffGraph, namespace, excludeIdentifierFn = _ => false) class MethodLocalPass(cpg: Cpg) extends LocalCreationPass[Method](cpg): - override def generateParts(): Array[Method] = cpg.method.internal.toArray - - override def runOnPart(diffGraph: DiffGraphBuilder, method: Method): Unit = - val parameters = method.parameter.name.toSet - addLocalsToAst( - diffGraph, - method.body, - excludeIdentifierFn = identifier => parameters.contains(identifier.name) - ) + override def generateParts(): Array[Method] = cpg.method.internal.toArray + + override def runOnPart(diffGraph: DiffGraphBuilder, method: Method): Unit = + val parameters = method.parameter.name.toSet + addLocalsToAst( + diffGraph, + method.body, + excludeIdentifierFn = identifier => parameters.contains(identifier.name) + ) diff --git a/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/passes/PhpSetKnownTypes.scala b/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/passes/PhpSetKnownTypes.scala index 1c0306c1..94b266e6 100644 --- a/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/passes/PhpSetKnownTypes.scala +++ b/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/passes/PhpSetKnownTypes.scala @@ -30,50 +30,50 @@ case class KnownFunction( class PhpSetKnownTypesPass(cpg: Cpg, knownTypesFile: Option[JFile] = None) extends ForkJoinParallelCpgPass[KnownFunction](cpg): - private val logger = LoggerFactory.getLogger(getClass) + private val logger = LoggerFactory.getLogger(getClass) - override def generateParts(): Array[KnownFunction] = - /* parse file and return each row as a KnownFunction object */ - val source = knownTypesFile match - case Some(file) => Source.fromFile(file) - case _ => Source.fromResource("known_function_signatures.txt") - val contents = source.getLines().filterNot(_.startsWith("//")) - val arr = contents.flatMap(line => createKnownFunctionFromLine(line)).toArray - source.close - arr + override def generateParts(): Array[KnownFunction] = + /* parse file and return each row as a KnownFunction object */ + val source = knownTypesFile match + case Some(file) => Source.fromFile(file) + case _ => Source.fromResource("known_function_signatures.txt") + val contents = source.getLines().filterNot(_.startsWith("//")) + val arr = contents.flatMap(line => createKnownFunctionFromLine(line)).toArray + source.close + arr - override def runOnPart( - builder: overflowdb.BatchedUpdate.DiffGraphBuilder, - part: KnownFunction - ): Unit = - /* calculate the result of this part - this is done as a concurrent task */ - val builtinMethod = cpg.method.fullNameExact(part.name).l - builtinMethod.foreach(mNode => - setTypes(builder, mNode.methodReturn, part.rTypes) - (mNode.parameter.l.zip(part.pTypes)).map((p, pTypes) => setTypes(builder, p, pTypes)) - ) + override def runOnPart( + builder: overflowdb.BatchedUpdate.DiffGraphBuilder, + part: KnownFunction + ): Unit = + /* calculate the result of this part - this is done as a concurrent task */ + val builtinMethod = cpg.method.fullNameExact(part.name).l + builtinMethod.foreach(mNode => + setTypes(builder, mNode.methodReturn, part.rTypes) + (mNode.parameter.l.zip(part.pTypes)).map((p, pTypes) => setTypes(builder, p, pTypes)) + ) - def createKnownFunctionFromLine(line: String): Option[KnownFunction] = - line.split(";").map(_.strip).toList match - case Nil => None - case name :: Nil => Some(KnownFunction(name)) - case name :: rTypes :: Nil => Some(KnownFunction(name, scanReturnTypes(rTypes))) - case name :: rTypes :: pTypes => - Some(KnownFunction(name, scanReturnTypes(rTypes), scanParamTypes(pTypes))) + def createKnownFunctionFromLine(line: String): Option[KnownFunction] = + line.split(";").map(_.strip).toList match + case Nil => None + case name :: Nil => Some(KnownFunction(name)) + case name :: rTypes :: Nil => Some(KnownFunction(name, scanReturnTypes(rTypes))) + case name :: rTypes :: pTypes => + Some(KnownFunction(name, scanReturnTypes(rTypes), scanParamTypes(pTypes))) - /* From comma separated list of types, create list of types. */ - def scanReturnTypes(rTypesRaw: String): Seq[String] = rTypesRaw.split(",").map(_.strip).toSeq + /* From comma separated list of types, create list of types. */ + def scanReturnTypes(rTypesRaw: String): Seq[String] = rTypesRaw.split(",").map(_.strip).toSeq - /* From a semicolon separated list of parameters, each with a comma separated list of types, - * create a list of lists of types. */ - def scanParamTypes(pTypesRawArr: List[String]): Seq[Seq[String]] = - pTypesRawArr.map(paramTypeRaw => paramTypeRaw.split(",").map(_.strip).toSeq).toSeq + /* From a semicolon separated list of parameters, each with a comma separated list of types, + * create a list of lists of types. */ + def scanParamTypes(pTypesRawArr: List[String]): Seq[Seq[String]] = + pTypesRawArr.map(paramTypeRaw => paramTypeRaw.split(",").map(_.strip).toSeq).toSeq - protected def setTypes( - builder: overflowdb.BatchedUpdate.DiffGraphBuilder, - n: StoredNode, - types: Seq[String] - ): Unit = - if types.size == 1 then builder.setNodeProperty(n, PropertyNames.TYPE_FULL_NAME, types.head) - else builder.setNodeProperty(n, PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, types) + protected def setTypes( + builder: overflowdb.BatchedUpdate.DiffGraphBuilder, + n: StoredNode, + types: Seq[String] + ): Unit = + if types.size == 1 then builder.setNodeProperty(n, PropertyNames.TYPE_FULL_NAME, types.head) + else builder.setNodeProperty(n, PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, types) end PhpSetKnownTypesPass diff --git a/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/utils/ArrayIndexTracker.scala b/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/utils/ArrayIndexTracker.scala index b723c71a..674d55df 100644 --- a/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/utils/ArrayIndexTracker.scala +++ b/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/utils/ArrayIndexTracker.scala @@ -1,16 +1,16 @@ package io.appthreat.php2atom.datastructures class ArrayIndexTracker: - private var currentValue = 0 + private var currentValue = 0 - def next: String = - val nextVal = currentValue - currentValue += 1 - nextVal.toString + def next: String = + val nextVal = currentValue + currentValue += 1 + nextVal.toString - def updateValue(newValue: Int): Unit = - if newValue >= currentValue then - currentValue = newValue + 1 + def updateValue(newValue: Int): Unit = + if newValue >= currentValue then + currentValue = newValue + 1 object ArrayIndexTracker: - def apply(): ArrayIndexTracker = new ArrayIndexTracker + def apply(): ArrayIndexTracker = new ArrayIndexTracker diff --git a/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/utils/Scope.scala b/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/utils/Scope.scala index 93bed210..a0bd1ca3 100644 --- a/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/utils/Scope.scala +++ b/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/utils/Scope.scala @@ -18,89 +18,89 @@ import scala.collection.mutable class Scope(implicit nextClosureName: () => String) extends X2CpgScope[String, NewNode, PhpScopeElement]: - private val logger = LoggerFactory.getLogger(this.getClass) + private val logger = LoggerFactory.getLogger(this.getClass) - private var constAndStaticInits: List[mutable.ArrayBuffer[Ast]] = Nil - private var fieldInits: List[mutable.ArrayBuffer[Ast]] = Nil - private val anonymousMethods = mutable.ArrayBuffer[Ast]() + private var constAndStaticInits: List[mutable.ArrayBuffer[Ast]] = Nil + private var fieldInits: List[mutable.ArrayBuffer[Ast]] = Nil + private val anonymousMethods = mutable.ArrayBuffer[Ast]() - def pushNewScope(scopeNode: NewNode): Unit = - scopeNode match - case method: NewMethod => - super.pushNewScope(PhpScopeElement(method)) + def pushNewScope(scopeNode: NewNode): Unit = + scopeNode match + case method: NewMethod => + super.pushNewScope(PhpScopeElement(method)) - case typeDecl: NewTypeDecl => - constAndStaticInits = mutable.ArrayBuffer[Ast]() :: constAndStaticInits - fieldInits = mutable.ArrayBuffer[Ast]() :: fieldInits - super.pushNewScope(PhpScopeElement(typeDecl)) + case typeDecl: NewTypeDecl => + constAndStaticInits = mutable.ArrayBuffer[Ast]() :: constAndStaticInits + fieldInits = mutable.ArrayBuffer[Ast]() :: fieldInits + super.pushNewScope(PhpScopeElement(typeDecl)) - case namespace: NewNamespaceBlock => - super.pushNewScope(PhpScopeElement(namespace)) + case namespace: NewNamespaceBlock => + super.pushNewScope(PhpScopeElement(namespace)) - case invalid => - logger.debug(s"pushNewScope called with invalid node $invalid. Ignoring!") + case invalid => + logger.debug(s"pushNewScope called with invalid node $invalid. Ignoring!") - override def popScope(): Option[PhpScopeElement] = - val scopeNode = super.popScope() + override def popScope(): Option[PhpScopeElement] = + val scopeNode = super.popScope() - scopeNode.map(_.node) match - case Some(_: NewTypeDecl) => - // TODO This is unsafe to catch errors for now - constAndStaticInits = constAndStaticInits.tail - fieldInits = fieldInits.tail + scopeNode.map(_.node) match + case Some(_: NewTypeDecl) => + // TODO This is unsafe to catch errors for now + constAndStaticInits = constAndStaticInits.tail + fieldInits = fieldInits.tail - case _ => // Nothing to do here - scopeNode + case _ => // Nothing to do here + scopeNode - override def addToScope(identifier: String, variable: NewNode): PhpScopeElement = - super.addToScope(identifier, variable) + override def addToScope(identifier: String, variable: NewNode): PhpScopeElement = + super.addToScope(identifier, variable) - def addAnonymousMethod(methodAst: Ast): Unit = anonymousMethods.addOne(methodAst) + def addAnonymousMethod(methodAst: Ast): Unit = anonymousMethods.addOne(methodAst) - def getAndClearAnonymousMethods: List[Ast] = - val methods = anonymousMethods.toList - anonymousMethods.clear() - methods + def getAndClearAnonymousMethods: List[Ast] = + val methods = anonymousMethods.toList + anonymousMethods.clear() + methods - def getEnclosingNamespaceNames: List[String] = - stack.map(_.scopeNode.node).collect { case ns: NewNamespaceBlock => ns.name }.reverse + def getEnclosingNamespaceNames: List[String] = + stack.map(_.scopeNode.node).collect { case ns: NewNamespaceBlock => ns.name }.reverse - def getEnclosingTypeDeclTypeName: Option[String] = - stack.map(_.scopeNode.node).collectFirst { case td: NewTypeDecl => td }.map(_.name) + def getEnclosingTypeDeclTypeName: Option[String] = + stack.map(_.scopeNode.node).collectFirst { case td: NewTypeDecl => td }.map(_.name) - def getEnclosingTypeDeclTypeFullName: Option[String] = - stack.map(_.scopeNode.node).collectFirst { case td: NewTypeDecl => td }.map(_.fullName) + def getEnclosingTypeDeclTypeFullName: Option[String] = + stack.map(_.scopeNode.node).collectFirst { case td: NewTypeDecl => td }.map(_.fullName) - def addConstOrStaticInitToScope(ast: Ast): Unit = - addInitToScope(ast, constAndStaticInits) - def getConstAndStaticInits: List[Ast] = - getInits(constAndStaticInits) + def addConstOrStaticInitToScope(ast: Ast): Unit = + addInitToScope(ast, constAndStaticInits) + def getConstAndStaticInits: List[Ast] = + getInits(constAndStaticInits) - def addFieldInitToScope(ast: Ast): Unit = - addInitToScope(ast, fieldInits) + def addFieldInitToScope(ast: Ast): Unit = + addInitToScope(ast, fieldInits) - def getFieldInits: List[Ast] = - getInits(fieldInits) + def getFieldInits: List[Ast] = + getInits(fieldInits) - def getScopedClosureName: String = - stack.headOption match - case Some(scopeElement) => - scopeElement.scopeNode.getClosureMethodName + def getScopedClosureName: String = + stack.headOption match + case Some(scopeElement) => + scopeElement.scopeNode.getClosureMethodName - case None => - logger.debug( - "BUG: Attempting to get scopedClosureName, but no scope has been push. Defaulting to unscoped" - ) - NameConstants.Closure + case None => + logger.debug( + "BUG: Attempting to get scopedClosureName, but no scope has been push. Defaulting to unscoped" + ) + NameConstants.Closure - private def addInitToScope(ast: Ast, initList: List[mutable.ArrayBuffer[Ast]]): Unit = - // TODO This is unsafe to catch errors for now - initList.head.addOne(ast) + private def addInitToScope(ast: Ast, initList: List[mutable.ArrayBuffer[Ast]]): Unit = + // TODO This is unsafe to catch errors for now + initList.head.addOne(ast) - private def getInits(initList: List[mutable.ArrayBuffer[Ast]]): List[Ast] = - // TODO This is unsafe to catch errors for now - val ret = initList.head.toList - // These ASTs should only be added once to avoid aliasing issues. - initList.head.clear() - ret + private def getInits(initList: List[mutable.ArrayBuffer[Ast]]): List[Ast] = + // TODO This is unsafe to catch errors for now + val ret = initList.head.toList + // These ASTs should only be added once to avoid aliasing issues. + initList.head.clear() + ret end Scope diff --git a/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/utils/ScopeElement.scala b/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/utils/ScopeElement.scala index 7e7ded98..6b843da0 100644 --- a/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/utils/ScopeElement.scala +++ b/platform/frontends/php2atom/src/main/scala/io/appthreat/php2atom/utils/ScopeElement.scala @@ -12,20 +12,18 @@ class PhpScopeElement private (val node: NewNode, scopeName: String)(implicit nextClosureName: () => String ): - def getClosureMethodName: String = - s"$scopeName$InstanceMethodDelimiter${nextClosureName()}" + def getClosureMethodName: String = + s"$scopeName$InstanceMethodDelimiter${nextClosureName()}" object PhpScopeElement: - def apply(method: NewMethod)(implicit nextClosureName: () => String): PhpScopeElement = - new PhpScopeElement(method, method.fullName) + def apply(method: NewMethod)(implicit nextClosureName: () => String): PhpScopeElement = + new PhpScopeElement(method, method.fullName) - def apply(typeDecl: NewTypeDecl)(implicit nextClosureName: () => String): PhpScopeElement = - new PhpScopeElement(typeDecl, typeDecl.fullName) + def apply(typeDecl: NewTypeDecl)(implicit nextClosureName: () => String): PhpScopeElement = + new PhpScopeElement(typeDecl, typeDecl.fullName) - def apply(namespace: NewNamespaceBlock)(implicit - nextClosureName: () => String - ): PhpScopeElement = - new PhpScopeElement(namespace, namespace.fullName) + def apply(namespace: NewNamespaceBlock)(implicit nextClosureName: () => String): PhpScopeElement = + new PhpScopeElement(namespace, namespace.fullName) - def unapply(scopeElement: PhpScopeElement): Option[NewNode] = - Some(scopeElement.node) + def unapply(scopeElement: PhpScopeElement): Option[NewNode] = + Some(scopeElement.node) diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/AutoIncIndex.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/AutoIncIndex.scala index ef3000c9..16a650a8 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/AutoIncIndex.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/AutoIncIndex.scala @@ -1,7 +1,7 @@ package io.appthreat.pysrc2cpg class AutoIncIndex(private var index: Int): - def getAndInc: Int = - val ret = index - index += 1 - ret + def getAndInc: Int = + val ret = index + index += 1 + ret diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/CodeToCpg.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/CodeToCpg.scala index b5d4ebad..6d43b453 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/CodeToCpg.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/CodeToCpg.scala @@ -12,28 +12,28 @@ class CodeToCpg( inputProvider: Iterable[InputProvider], schemaValidationMode: ValidationMode ) extends ConcurrentWriterCpgPass[InputProvider](cpg): - import CodeToCpg.logger + import CodeToCpg.logger - override def generateParts(): Array[InputProvider] = inputProvider.toArray + override def generateParts(): Array[InputProvider] = inputProvider.toArray - override def runOnPart(diffGraph: DiffGraphBuilder, inputProvider: InputProvider): Unit = - val inputPair = inputProvider() - try - val parser = new PyParser() - val lineBreakCorrectedCode = inputPair.content.replace("\r\n", "\n").replace("\r", "\n") - val astRoot = parser.parse(lineBreakCorrectedCode) - val nodeToCode = new NodeToCode(lineBreakCorrectedCode) - val astVisitor = new PythonAstVisitor(inputPair.relFileName, nodeToCode, PythonV2AndV3)( - schemaValidationMode - ) - astVisitor.convert(astRoot) + override def runOnPart(diffGraph: DiffGraphBuilder, inputProvider: InputProvider): Unit = + val inputPair = inputProvider() + try + val parser = new PyParser() + val lineBreakCorrectedCode = inputPair.content.replace("\r\n", "\n").replace("\r", "\n") + val astRoot = parser.parse(lineBreakCorrectedCode) + val nodeToCode = new NodeToCode(lineBreakCorrectedCode) + val astVisitor = new PythonAstVisitor(inputPair.relFileName, nodeToCode, PythonV2AndV3)( + schemaValidationMode + ) + astVisitor.convert(astRoot) - diffGraph.absorb(astVisitor.getDiffGraph) - catch - case exception: Throwable => - logger.debug(s"Failed to convert file ${inputPair.relFileName}", exception) - Iterator.empty + diffGraph.absorb(astVisitor.getDiffGraph) + catch + case exception: Throwable => + logger.debug(s"Failed to convert file ${inputPair.relFileName}", exception) + Iterator.empty end CodeToCpg object CodeToCpg: - private val logger = LoggerFactory.getLogger(getClass) + private val logger = LoggerFactory.getLogger(getClass) diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/ConfigFileCreationPass.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/ConfigFileCreationPass.scala index 8d2aec2c..fff2a776 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/ConfigFileCreationPass.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/ConfigFileCreationPass.scala @@ -7,18 +7,18 @@ import io.shiftleft.codepropertygraph.Cpg class ConfigFileCreationPass(cpg: Cpg, requirementsTxt: String = "requirements.txt") extends XConfigFileCreationPass(cpg): - override val configFileFilters: List[File => Boolean] = List( - // TOML files - extensionFilter(".toml"), - // INI files - extensionFilter(".ini"), - // YAML files - extensionFilter(".yaml"), - extensionFilter(".lock"), - pathEndFilter("bom.json"), - pathEndFilter(".cdx.json"), - pathEndFilter("chennai.json"), - pathEndFilter("setup.cfg"), - // Requirements.txt - pathEndFilter(requirementsTxt) - ) + override val configFileFilters: List[File => Boolean] = List( + // TOML files + extensionFilter(".toml"), + // INI files + extensionFilter(".ini"), + // YAML files + extensionFilter(".yaml"), + extensionFilter(".lock"), + pathEndFilter("bom.json"), + pathEndFilter(".cdx.json"), + pathEndFilter("chennai.json"), + pathEndFilter("setup.cfg"), + // Requirements.txt + pathEndFilter(requirementsTxt) + ) diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Constants.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Constants.scala index d84fbf30..6bcb791b 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Constants.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Constants.scala @@ -1,5 +1,5 @@ package io.appthreat.pysrc2cpg object Constants: - val ANY = "ANY" - val GLOBAL_NAMESPACE = "" + val ANY = "ANY" + val GLOBAL_NAMESPACE = "" diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/ContextStack.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/ContextStack.scala index b2668f19..8db13257 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/ContextStack.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/ContextStack.scala @@ -10,416 +10,416 @@ import org.slf4j.LoggerFactory import scala.collection.mutable object ContextStack: - private val logger = LoggerFactory.getLogger(getClass) + private val logger = LoggerFactory.getLogger(getClass) - def transferLineColInfo(src: NewIdentifier, tgt: NewLocal): Unit = - src.lineNumber match - // If there are multiple occurrences and the local is already set, ignore later updates - case Some(srcLineNo) - if tgt.lineNumber.isEmpty || !tgt.lineNumber.exists(_ < srcLineNo) => - tgt.lineNumber(src.lineNumber) - tgt.columnNumber(src.columnNumber) - case _ => + def transferLineColInfo(src: NewIdentifier, tgt: NewLocal): Unit = + src.lineNumber match + // If there are multiple occurrences and the local is already set, ignore later updates + case Some(srcLineNo) + if tgt.lineNumber.isEmpty || !tgt.lineNumber.exists(_ < srcLineNo) => + tgt.lineNumber(src.lineNumber) + tgt.columnNumber(src.columnNumber) + case _ => class ContextStack: - import ContextStack.logger - - private trait Context: - val astParent: nodes.NewNode - val order: AutoIncIndex - val variables: mutable.Map[String, nodes.NewNode] - var lambdaCounter: Int - - private class MethodContext( - val scopeName: Option[String], - val astParent: nodes.NewNode, - val order: AutoIncIndex, - val isClassBodyMethod: Boolean = false, - val methodBlockNode: Option[nodes.NewBlock] = None, - val methodRefNode: Option[nodes.NewMethodRef] = None, - val variables: mutable.Map[String, nodes.NewNode] = mutable.Map.empty, - val globalVariables: mutable.Set[String] = mutable.Set.empty, - val nonLocalVariables: mutable.Set[String] = mutable.Set.empty, - var lambdaCounter: Int = 0 - ) extends Context {} - - private class ClassContext( - val scopeName: Option[String], - val astParent: nodes.NewNode, - val order: AutoIncIndex, - val variables: mutable.Map[String, nodes.NewNode] = mutable.Map.empty, - var lambdaCounter: Int = 0 - ) extends Context {} - - // Used to represent comprehension variable and exception - // handler context. - // E.g.: [x for x in y] creates an extra context with a - // local variable x which is different from a possible x - // in the surrounding method context. The same applies - // to x in: - // try: - // pass - // except e as x: - // pass - private class SpecialBlockContext( - val astParent: nodes.NewNode, - val order: AutoIncIndex, - val variables: mutable.Map[String, nodes.NewNode] = mutable.Map.empty, - var lambdaCounter: Int = 0 - ) extends Context {} - - private case class VariableReference( - identifier: nodes.NewIdentifier, - memOp: MemoryOperation, - // Context stack as it was when VariableReference - // was created. Context objects are and need to - // shared between different VariableReference - // instances because the changes in the variable - // maps need to be in sync. - stack: List[Context] - ) - - private var stack = List[Context]() - private val variableReferences = mutable.ArrayBuffer.empty[VariableReference] - private var moduleMethodContext = Option.empty[MethodContext] - private var fileNamespaceBlock = Option.empty[nodes.NewNamespaceBlock] - private val fileNamespaceBlockOrder = new AutoIncIndex(1) - - private def push(context: Context): Unit = - stack = context :: stack - - def pushMethod( - scopeName: Option[String], - methodNode: nodes.NewMethod, - methodBlockNode: nodes.NewBlock, - methodRefNode: Option[nodes.NewMethodRef] - ): Unit = - val isClassBodyMethod = stack.headOption.exists(_.isInstanceOf[ClassContext]) - - val methodContext = - new MethodContext( - scopeName, - methodNode, - new AutoIncIndex(1), - isClassBodyMethod, - Some(methodBlockNode), - methodRefNode - ) - if moduleMethodContext.isEmpty then - moduleMethodContext = Some(methodContext) - push(methodContext) - end pushMethod - - def pushClass(scopeName: Option[String], classNode: nodes.NewTypeDecl): Unit = - push(new ClassContext(scopeName, classNode, new AutoIncIndex(1))) - - def pushSpecialContext(): Unit = - val methodContext = findEnclosingMethodContext(stack) - push(new SpecialBlockContext(methodContext.astParent, methodContext.order)) - - def pop(): Unit = - stack = stack.tail - - def setFileNamespaceBlock(namespaceBlock: nodes.NewNamespaceBlock): Unit = - fileNamespaceBlock = Some(namespaceBlock) - - def addVariableReference(identifier: nodes.NewIdentifier, memOp: MemoryOperation): Unit = - variableReferences.append(VariableReference(identifier, memOp, stack)) - - def getAndIncLambdaCounter(): Int = - val result = stack.head.lambdaCounter - stack.head.lambdaCounter += 1 - result - - private def findEnclosingMethodContext(contextStack: List[Context]): MethodContext = - contextStack.find(_.isInstanceOf[MethodContext]).get.asInstanceOf[MethodContext] - - def findEnclosingTypeDecl(): Option[NewNode] = - stack.find(_.isInstanceOf[ClassContext]) match - case Some(classContext: ClassContext) => - Some(classContext.astParent) - case _ => None - - def createIdentifierLinks( - createLocal: (String, Option[String]) => nodes.NewLocal, - createClosureBinding: (String, String) => nodes.NewClosureBinding, - createAstEdge: (nodes.NewNode, nodes.NewNode, Int) => Unit, - createRefEdge: (nodes.NewNode, nodes.NewNode) => Unit, - createCaptureEdge: (nodes.NewNode, nodes.NewNode) => Unit - ): Unit = - // Before we do any linking, we iterate over all variable references and - // create a variable in the module method context for each global variable - // with a store operation on it. - // This is necessary because there might be load/delete operations - // referencing the global variable which are syntactically before the store - // operations. - variableReferences.foreach { case VariableReference(identifier, memOp, contextStack) => - val name = identifier.name - if - memOp == Store && - findEnclosingMethodContext(contextStack).globalVariables.contains(name) && - !moduleMethodContext.get.variables.contains(name) - then - val localNode = createLocal(name, None) - transferLineColInfo(identifier, localNode) - createAstEdge( - localNode, - moduleMethodContext.get.methodBlockNode.get, - moduleMethodContext.get.order.getAndInc - ) - moduleMethodContext.get.variables.put(name, localNode) - } - - // Variable references processing needs to be ordered by context depth in - // order to make sure that variables captured into deeper nested contexts - // are already created. - val sortedVariableRefs = variableReferences.sortBy(_.stack.size) - sortedVariableRefs.foreach { case VariableReference(identifier, memOp, contextStack) => - val name = identifier.name - // Store and delete operations look up variable only in method scope. - // Load operations also look up captured or global variables. - // If a store and load/del happens in the same context, the store must - // come first. Otherwise it is not valid Python, which we assume here. - if memOp == Load then - linkLocalOrCapturing( - createLocal, - createClosureBinding, - createAstEdge, - createRefEdge, - createCaptureEdge, - identifier, - name, - contextStack - ) - else - val enclosingMethodContext = findEnclosingMethodContext(contextStack) - - if - enclosingMethodContext.globalVariables.contains(name) || - enclosingMethodContext.nonLocalVariables.contains(name) - then - linkLocalOrCapturing( - createLocal, - createClosureBinding, - createAstEdge, - createRefEdge, - createCaptureEdge, - identifier, - name, - contextStack - ) - else if memOp == Store then - var variableNode = lookupVariableInMethod(name, contextStack) - if variableNode.isEmpty then - val localNode = createLocal(name, None) - transferLineColInfo(identifier, localNode) - val enclosingMethodContext = findEnclosingMethodContext(contextStack) - createAstEdge( - localNode, - enclosingMethodContext.methodBlockNode.get, - enclosingMethodContext.order.getAndInc - ) - enclosingMethodContext.variables.put(name, localNode) - variableNode = Some(localNode) - createRefEdge(variableNode.get, identifier) - else if memOp == Del then - val variableNode = lookupVariableInMethod(name, contextStack) - variableNode match - case Some(variableNode) => - createRefEdge(variableNode, identifier) - case None => - // When we could not find a matching variable we get here and create a local in - // the method context so that we can link something and fullfil the CPG - // format requirements. - // For example this happens when there are wildcard imports directly into the - // modules namespace. - val localNode = createLocal(name, None) - transferLineColInfo(identifier, localNode) - val methodContext = findEnclosingMethodContext(contextStack) - createAstEdge( - localNode, - methodContext.methodBlockNode.get, - methodContext.order.getAndInc - ) - methodContext.variables.put(name, localNode) - createRefEdge(localNode, identifier) - end if - end if - } - end createIdentifierLinks - - /** Assignments to variables on the module-level may be exported to other modules and behave as - * inter-procedurally global variables. - * @param lhs - * the LHS node of an assignment - */ - def considerAsGlobalVariable(lhs: NewNode): Unit = - lhs match - case n: NewIdentifier - if findEnclosingMethodContext(stack).scopeName.contains("") => - addGlobalVariable(n.name) - case _ => - - /** For module-methods, the variables of this method can be imported into other modules which - * resembles behaviour much like fields/members. This inter-procedural accessibility should be - * marked via the module's type decl node. - */ - def createMemberLinks( - moduleTypeDecl: NewTypeDecl, - astEdgeLinker: (NewNode, NewNode, Int) => Unit - ): Unit = - val globalVarsForEnclMethod = findEnclosingMethodContext(stack).globalVariables - variableReferences - .map(_.identifier) - .filter(i => globalVarsForEnclMethod.contains(i.name)) - .sortBy(i => (i.lineNumber, i.columnNumber)) - .distinctBy(_.name) - .map(i => - NewMember() - .name(i.name) - .typeFullName(Constants.ANY) - .dynamicTypeHintFullName(i.dynamicTypeHintFullName) - .lineNumber(i.lineNumber) - .columnNumber(i.columnNumber) - .code(i.name) + import ContextStack.logger + + private trait Context: + val astParent: nodes.NewNode + val order: AutoIncIndex + val variables: mutable.Map[String, nodes.NewNode] + var lambdaCounter: Int + + private class MethodContext( + val scopeName: Option[String], + val astParent: nodes.NewNode, + val order: AutoIncIndex, + val isClassBodyMethod: Boolean = false, + val methodBlockNode: Option[nodes.NewBlock] = None, + val methodRefNode: Option[nodes.NewMethodRef] = None, + val variables: mutable.Map[String, nodes.NewNode] = mutable.Map.empty, + val globalVariables: mutable.Set[String] = mutable.Set.empty, + val nonLocalVariables: mutable.Set[String] = mutable.Set.empty, + var lambdaCounter: Int = 0 + ) extends Context {} + + private class ClassContext( + val scopeName: Option[String], + val astParent: nodes.NewNode, + val order: AutoIncIndex, + val variables: mutable.Map[String, nodes.NewNode] = mutable.Map.empty, + var lambdaCounter: Int = 0 + ) extends Context {} + + // Used to represent comprehension variable and exception + // handler context. + // E.g.: [x for x in y] creates an extra context with a + // local variable x which is different from a possible x + // in the surrounding method context. The same applies + // to x in: + // try: + // pass + // except e as x: + // pass + private class SpecialBlockContext( + val astParent: nodes.NewNode, + val order: AutoIncIndex, + val variables: mutable.Map[String, nodes.NewNode] = mutable.Map.empty, + var lambdaCounter: Int = 0 + ) extends Context {} + + private case class VariableReference( + identifier: nodes.NewIdentifier, + memOp: MemoryOperation, + // Context stack as it was when VariableReference + // was created. Context objects are and need to + // shared between different VariableReference + // instances because the changes in the variable + // maps need to be in sync. + stack: List[Context] + ) + + private var stack = List[Context]() + private val variableReferences = mutable.ArrayBuffer.empty[VariableReference] + private var moduleMethodContext = Option.empty[MethodContext] + private var fileNamespaceBlock = Option.empty[nodes.NewNamespaceBlock] + private val fileNamespaceBlockOrder = new AutoIncIndex(1) + + private def push(context: Context): Unit = + stack = context :: stack + + def pushMethod( + scopeName: Option[String], + methodNode: nodes.NewMethod, + methodBlockNode: nodes.NewBlock, + methodRefNode: Option[nodes.NewMethodRef] + ): Unit = + val isClassBodyMethod = stack.headOption.exists(_.isInstanceOf[ClassContext]) + + val methodContext = + new MethodContext( + scopeName, + methodNode, + new AutoIncIndex(1), + isClassBodyMethod, + Some(methodBlockNode), + methodRefNode + ) + if moduleMethodContext.isEmpty then + moduleMethodContext = Some(methodContext) + push(methodContext) + end pushMethod + + def pushClass(scopeName: Option[String], classNode: nodes.NewTypeDecl): Unit = + push(new ClassContext(scopeName, classNode, new AutoIncIndex(1))) + + def pushSpecialContext(): Unit = + val methodContext = findEnclosingMethodContext(stack) + push(new SpecialBlockContext(methodContext.astParent, methodContext.order)) + + def pop(): Unit = + stack = stack.tail + + def setFileNamespaceBlock(namespaceBlock: nodes.NewNamespaceBlock): Unit = + fileNamespaceBlock = Some(namespaceBlock) + + def addVariableReference(identifier: nodes.NewIdentifier, memOp: MemoryOperation): Unit = + variableReferences.append(VariableReference(identifier, memOp, stack)) + + def getAndIncLambdaCounter(): Int = + val result = stack.head.lambdaCounter + stack.head.lambdaCounter += 1 + result + + private def findEnclosingMethodContext(contextStack: List[Context]): MethodContext = + contextStack.find(_.isInstanceOf[MethodContext]).get.asInstanceOf[MethodContext] + + def findEnclosingTypeDecl(): Option[NewNode] = + stack.find(_.isInstanceOf[ClassContext]) match + case Some(classContext: ClassContext) => + Some(classContext.astParent) + case _ => None + + def createIdentifierLinks( + createLocal: (String, Option[String]) => nodes.NewLocal, + createClosureBinding: (String, String) => nodes.NewClosureBinding, + createAstEdge: (nodes.NewNode, nodes.NewNode, Int) => Unit, + createRefEdge: (nodes.NewNode, nodes.NewNode) => Unit, + createCaptureEdge: (nodes.NewNode, nodes.NewNode) => Unit + ): Unit = + // Before we do any linking, we iterate over all variable references and + // create a variable in the module method context for each global variable + // with a store operation on it. + // This is necessary because there might be load/delete operations + // referencing the global variable which are syntactically before the store + // operations. + variableReferences.foreach { case VariableReference(identifier, memOp, contextStack) => + val name = identifier.name + if + memOp == Store && + findEnclosingMethodContext(contextStack).globalVariables.contains(name) && + !moduleMethodContext.get.variables.contains(name) + then + val localNode = createLocal(name, None) + transferLineColInfo(identifier, localNode) + createAstEdge( + localNode, + moduleMethodContext.get.methodBlockNode.get, + moduleMethodContext.get.order.getAndInc + ) + moduleMethodContext.get.variables.put(name, localNode) + } + + // Variable references processing needs to be ordered by context depth in + // order to make sure that variables captured into deeper nested contexts + // are already created. + val sortedVariableRefs = variableReferences.sortBy(_.stack.size) + sortedVariableRefs.foreach { case VariableReference(identifier, memOp, contextStack) => + val name = identifier.name + // Store and delete operations look up variable only in method scope. + // Load operations also look up captured or global variables. + // If a store and load/del happens in the same context, the store must + // come first. Otherwise it is not valid Python, which we assume here. + if memOp == Load then + linkLocalOrCapturing( + createLocal, + createClosureBinding, + createAstEdge, + createRefEdge, + createCaptureEdge, + identifier, + name, + contextStack + ) + else + val enclosingMethodContext = findEnclosingMethodContext(contextStack) + + if + enclosingMethodContext.globalVariables.contains(name) || + enclosingMethodContext.nonLocalVariables.contains(name) + then + linkLocalOrCapturing( + createLocal, + createClosureBinding, + createAstEdge, + createRefEdge, + createCaptureEdge, + identifier, + name, + contextStack ) - .zipWithIndex - .foreach { case (m, idx) => astEdgeLinker(m, moduleTypeDecl, idx + 1) } - end createMemberLinks - - private def linkLocalOrCapturing( - createLocal: (String, Option[String]) => NewLocal, - createClosureBinding: (String, String) => NewClosureBinding, - createAstEdge: (NewNode, NewNode, Int) => Unit, - createRefEdge: (NewNode, NewNode) => Unit, - createCaptureEdge: (NewNode, NewNode) => Unit, - identifier: NewIdentifier, - name: String, - contextStack: List[Context] - ): Unit = - var identifierOrClosureBindingToLink: nodes.NewNode = identifier - val stackIt = contextStack.iterator - var contextHasVariable = false - val startContext = contextStack.head - while stackIt.hasNext && !contextHasVariable do - val context = stackIt.next() - - context match - case methodContext: MethodContext => - // Context is only relevant for linking if it is not a class body methods context - // or the identifier/reference itself is from the class body method context. - if !methodContext.isClassBodyMethod || methodContext == startContext then - contextHasVariable = context.variables.contains(name) - - val closureBindingId = - methodContext.astParent.asInstanceOf[NewMethod].fullName + ":" + name - - if !contextHasVariable then - if context != moduleMethodContext.get then - val localNode = createLocal(name, Some(closureBindingId)) - transferLineColInfo(identifier, localNode) - createAstEdge( - localNode, - methodContext.methodBlockNode.get, - methodContext.order.getAndInc - ) - methodContext.variables.put(name, localNode) - else - // When we could not even find a matching variable in the module context we get - // here and create a local so that we can link something and fullfil the CPG - // format requirements. - // For example this happens when there are wildcard imports directly into the - // modules namespace. - val localNode = createLocal(name, None) - transferLineColInfo(identifier, localNode) - createAstEdge( - localNode, - methodContext.methodBlockNode.get, - methodContext.order.getAndInc - ) - methodContext.variables.put(name, localNode) - end if - val localNodeInContext = methodContext.variables(name) - - createRefEdge(localNodeInContext, identifierOrClosureBindingToLink) - - if !contextHasVariable && context != moduleMethodContext.get then - identifierOrClosureBindingToLink = - createClosureBinding(closureBindingId, name) - createCaptureEdge( - identifierOrClosureBindingToLink, - methodContext.methodRefNode.get - ) - case specialBlockContext: SpecialBlockContext => - contextHasVariable = context.variables.contains(name) - if contextHasVariable then - val localNodeInContext = specialBlockContext.variables(name) - createRefEdge(localNodeInContext, identifierOrClosureBindingToLink) - - case _: ClassContext => - assert(context.variables.isEmpty) - // Class context is not relevant for variable linking. - // The context relevant for this is the class method body context. - end match - end while - end linkLocalOrCapturing - - private def lookupVariableInMethod(name: String, stack: List[Context]): Option[nodes.NewNode] = - var variableNode = Option.empty[nodes.NewNode] - - val stackIt = stack.iterator - var lastContextWasMethod = false - while stackIt.hasNext && variableNode.isEmpty && !lastContextWasMethod do - val context = stackIt.next() - variableNode = context.variables.get(name) - lastContextWasMethod = context.isInstanceOf[MethodContext] - variableNode - - def addParameter(parameter: nodes.NewMethodParameterIn): Unit = - assert(stack.head.isInstanceOf[MethodContext]) - stack.head.variables.put(parameter.name, parameter) - - def addSpecialVariable(local: nodes.NewLocal): Unit = - assert(stack.head.isInstanceOf[SpecialBlockContext]) - stack.head.variables.put(local.name, local) - - def addGlobalVariable(name: String): Unit = - findEnclosingMethodContext(stack).globalVariables.add(name) - - def addNonLocalVariable(name: String): Unit = - findEnclosingMethodContext(stack).nonLocalVariables.add(name) - - // Together with the file name this is used to compute full names. - def qualName: String = - stack - .flatMap { - case methodContext: MethodContext => - methodContext.scopeName - case specialBlockContext: SpecialBlockContext => - None - case classContext: ClassContext => - classContext.scopeName - } - .reverse - .mkString(".") - - def astParent: nodes.NewNode = - stack match - case head :: _ => - head.astParent - case Nil => - fileNamespaceBlock.get - - def order: AutoIncIndex = - stack match - case head :: _ => - head.order - case Nil => - fileNamespaceBlockOrder - - def isClassContext: Boolean = - stack.nonEmpty && (stack.head match - case methodContext: MethodContext if methodContext.isClassBodyMethod => true - case _ => false + else if memOp == Store then + var variableNode = lookupVariableInMethod(name, contextStack) + if variableNode.isEmpty then + val localNode = createLocal(name, None) + transferLineColInfo(identifier, localNode) + val enclosingMethodContext = findEnclosingMethodContext(contextStack) + createAstEdge( + localNode, + enclosingMethodContext.methodBlockNode.get, + enclosingMethodContext.order.getAndInc + ) + enclosingMethodContext.variables.put(name, localNode) + variableNode = Some(localNode) + createRefEdge(variableNode.get, identifier) + else if memOp == Del then + val variableNode = lookupVariableInMethod(name, contextStack) + variableNode match + case Some(variableNode) => + createRefEdge(variableNode, identifier) + case None => + // When we could not find a matching variable we get here and create a local in + // the method context so that we can link something and fullfil the CPG + // format requirements. + // For example this happens when there are wildcard imports directly into the + // modules namespace. + val localNode = createLocal(name, None) + transferLineColInfo(identifier, localNode) + val methodContext = findEnclosingMethodContext(contextStack) + createAstEdge( + localNode, + methodContext.methodBlockNode.get, + methodContext.order.getAndInc + ) + methodContext.variables.put(name, localNode) + createRefEdge(localNode, identifier) + end if + end if + } + end createIdentifierLinks + + /** Assignments to variables on the module-level may be exported to other modules and behave as + * inter-procedurally global variables. + * @param lhs + * the LHS node of an assignment + */ + def considerAsGlobalVariable(lhs: NewNode): Unit = + lhs match + case n: NewIdentifier + if findEnclosingMethodContext(stack).scopeName.contains("") => + addGlobalVariable(n.name) + case _ => + + /** For module-methods, the variables of this method can be imported into other modules which + * resembles behaviour much like fields/members. This inter-procedural accessibility should be + * marked via the module's type decl node. + */ + def createMemberLinks( + moduleTypeDecl: NewTypeDecl, + astEdgeLinker: (NewNode, NewNode, Int) => Unit + ): Unit = + val globalVarsForEnclMethod = findEnclosingMethodContext(stack).globalVariables + variableReferences + .map(_.identifier) + .filter(i => globalVarsForEnclMethod.contains(i.name)) + .sortBy(i => (i.lineNumber, i.columnNumber)) + .distinctBy(_.name) + .map(i => + NewMember() + .name(i.name) + .typeFullName(Constants.ANY) + .dynamicTypeHintFullName(i.dynamicTypeHintFullName) + .lineNumber(i.lineNumber) + .columnNumber(i.columnNumber) + .code(i.name) ) + .zipWithIndex + .foreach { case (m, idx) => astEdgeLinker(m, moduleTypeDecl, idx + 1) } + end createMemberLinks + + private def linkLocalOrCapturing( + createLocal: (String, Option[String]) => NewLocal, + createClosureBinding: (String, String) => NewClosureBinding, + createAstEdge: (NewNode, NewNode, Int) => Unit, + createRefEdge: (NewNode, NewNode) => Unit, + createCaptureEdge: (NewNode, NewNode) => Unit, + identifier: NewIdentifier, + name: String, + contextStack: List[Context] + ): Unit = + var identifierOrClosureBindingToLink: nodes.NewNode = identifier + val stackIt = contextStack.iterator + var contextHasVariable = false + val startContext = contextStack.head + while stackIt.hasNext && !contextHasVariable do + val context = stackIt.next() + + context match + case methodContext: MethodContext => + // Context is only relevant for linking if it is not a class body methods context + // or the identifier/reference itself is from the class body method context. + if !methodContext.isClassBodyMethod || methodContext == startContext then + contextHasVariable = context.variables.contains(name) + + val closureBindingId = + methodContext.astParent.asInstanceOf[NewMethod].fullName + ":" + name + + if !contextHasVariable then + if context != moduleMethodContext.get then + val localNode = createLocal(name, Some(closureBindingId)) + transferLineColInfo(identifier, localNode) + createAstEdge( + localNode, + methodContext.methodBlockNode.get, + methodContext.order.getAndInc + ) + methodContext.variables.put(name, localNode) + else + // When we could not even find a matching variable in the module context we get + // here and create a local so that we can link something and fullfil the CPG + // format requirements. + // For example this happens when there are wildcard imports directly into the + // modules namespace. + val localNode = createLocal(name, None) + transferLineColInfo(identifier, localNode) + createAstEdge( + localNode, + methodContext.methodBlockNode.get, + methodContext.order.getAndInc + ) + methodContext.variables.put(name, localNode) + end if + val localNodeInContext = methodContext.variables(name) + + createRefEdge(localNodeInContext, identifierOrClosureBindingToLink) + + if !contextHasVariable && context != moduleMethodContext.get then + identifierOrClosureBindingToLink = + createClosureBinding(closureBindingId, name) + createCaptureEdge( + identifierOrClosureBindingToLink, + methodContext.methodRefNode.get + ) + case specialBlockContext: SpecialBlockContext => + contextHasVariable = context.variables.contains(name) + if contextHasVariable then + val localNodeInContext = specialBlockContext.variables(name) + createRefEdge(localNodeInContext, identifierOrClosureBindingToLink) + + case _: ClassContext => + assert(context.variables.isEmpty) + // Class context is not relevant for variable linking. + // The context relevant for this is the class method body context. + end match + end while + end linkLocalOrCapturing + + private def lookupVariableInMethod(name: String, stack: List[Context]): Option[nodes.NewNode] = + var variableNode = Option.empty[nodes.NewNode] + + val stackIt = stack.iterator + var lastContextWasMethod = false + while stackIt.hasNext && variableNode.isEmpty && !lastContextWasMethod do + val context = stackIt.next() + variableNode = context.variables.get(name) + lastContextWasMethod = context.isInstanceOf[MethodContext] + variableNode + + def addParameter(parameter: nodes.NewMethodParameterIn): Unit = + assert(stack.head.isInstanceOf[MethodContext]) + stack.head.variables.put(parameter.name, parameter) + + def addSpecialVariable(local: nodes.NewLocal): Unit = + assert(stack.head.isInstanceOf[SpecialBlockContext]) + stack.head.variables.put(local.name, local) + + def addGlobalVariable(name: String): Unit = + findEnclosingMethodContext(stack).globalVariables.add(name) + + def addNonLocalVariable(name: String): Unit = + findEnclosingMethodContext(stack).nonLocalVariables.add(name) + + // Together with the file name this is used to compute full names. + def qualName: String = + stack + .flatMap { + case methodContext: MethodContext => + methodContext.scopeName + case specialBlockContext: SpecialBlockContext => + None + case classContext: ClassContext => + classContext.scopeName + } + .reverse + .mkString(".") + + def astParent: nodes.NewNode = + stack match + case head :: _ => + head.astParent + case Nil => + fileNamespaceBlock.get + + def order: AutoIncIndex = + stack match + case head :: _ => + head.order + case Nil => + fileNamespaceBlockOrder + + def isClassContext: Boolean = + stack.nonEmpty && (stack.head match + case methodContext: MethodContext if methodContext.isClassBodyMethod => true + case _ => false + ) end ContextStack diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/DependenciesFromRequirementsTxtPass.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/DependenciesFromRequirementsTxtPass.scala index 175c8a30..4fc4977e 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/DependenciesFromRequirementsTxtPass.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/DependenciesFromRequirementsTxtPass.scala @@ -21,17 +21,17 @@ Werkzeug==1.0.1 ``` */ class DependenciesFromRequirementsTxtPass(cpg: Cpg) extends CpgPass(cpg): - private val logger: Logger = - LoggerFactory.getLogger(classOf[DependenciesFromRequirementsTxtPass]) - override def run(dstGraph: DiffGraphBuilder): Unit = - cpg.configFile.filter(_.name.endsWith("requirements.txt")).foreach { node => - val lines = node.content.split("\n") - lines.filter(_.matches("^[^=]+==[^=]+$")).foreach { line => - val keyValPattern: Regex = "^([^=]+)==([^=]+)$".r - for patternMatch <- keyValPattern.findAllMatchIn(line) do - val name = patternMatch.group(1) - val version = patternMatch.group(2) - val node = NewDependency().name(name).version(version).dependencyGroupId(name) - dstGraph.addNode(node) - } + private val logger: Logger = + LoggerFactory.getLogger(classOf[DependenciesFromRequirementsTxtPass]) + override def run(dstGraph: DiffGraphBuilder): Unit = + cpg.configFile.filter(_.name.endsWith("requirements.txt")).foreach { node => + val lines = node.content.split("\n") + lines.filter(_.matches("^[^=]+==[^=]+$")).foreach { line => + val keyValPattern: Regex = "^([^=]+)==([^=]+)$".r + for patternMatch <- keyValPattern.findAllMatchIn(line) do + val name = patternMatch.group(1) + val version = patternMatch.group(2) + val node = NewDependency().name(name).version(version).dependencyGroupId(name) + dstGraph.addNode(node) } + } diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/DynamicTypeHintFullNamePass.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/DynamicTypeHintFullNamePass.scala index b0e848c0..45b0970d 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/DynamicTypeHintFullNamePass.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/DynamicTypeHintFullNamePass.scala @@ -22,99 +22,99 @@ import java.util.regex.{Matcher, Pattern} */ class DynamicTypeHintFullNamePass(cpg: Cpg) extends ForkJoinParallelCpgPass[CfgNode](cpg): - private case class ImportScope(entity: Option[String], alias: Option[String]) + private case class ImportScope(entity: Option[String], alias: Option[String]) - private val fileToImports = cpg.imports.l - .flatMap(imp => imp.call.file.l.map { f => f.name -> imp }) - .groupBy(_._1) - .view - .mapValues(_.map { case (_, imp) => - ImportScope(imp.importedEntity, imp.importedAs) - }) + private val fileToImports = cpg.imports.l + .flatMap(imp => imp.call.file.l.map { f => f.name -> imp }) + .groupBy(_._1) + .view + .mapValues(_.map { case (_, imp) => + ImportScope(imp.importedEntity, imp.importedAs) + }) - override def generateParts(): Array[CfgNode] = - (cpg.methodReturn.filter(x => x.typeFullName != Constants.ANY) ++ cpg.parameter.filter(x => - x.typeFullName != Constants.ANY - )).toArray + override def generateParts(): Array[CfgNode] = + (cpg.methodReturn.filter(x => x.typeFullName != Constants.ANY) ++ cpg.parameter.filter(x => + x.typeFullName != Constants.ANY + )).toArray - override def runOnPart(builder: DiffGraphBuilder, part: CfgNode): Unit = - part match - case x: MethodReturn => runOnMethodReturn(builder, x) - case x: MethodParameterIn => runOnMethodParameter(builder, x) - case _ => + override def runOnPart(builder: DiffGraphBuilder, part: CfgNode): Unit = + part match + case x: MethodReturn => runOnMethodReturn(builder, x) + case x: MethodParameterIn => runOnMethodParameter(builder, x) + case _ => - private def runOnMethodReturn(diffGraph: DiffGraphBuilder, methodReturn: MethodReturn): Unit = - methodReturn.file.foreach { file => - val typeHint = methodReturn.typeFullName - val imports = - fileToImports.getOrElse(file.name, List.empty) ++ methodReturn.method.typeDecl - .map(td => - ImportScope(Option(pythonicTypeNameToImport(td.fullName)), Option(td.name)) - ) - .toList - imports - .filter { x => - // TODO: Handle * imports correctly - x.alias.exists { imported => - typeHint.matches(Pattern.quote(imported) + "(\\..+)*") - } - } - .flatMap(_.entity) - .foreach { importedEntity => - setTypeHints(diffGraph, methodReturn, typeHint, typeHint, importedEntity) - } - } - - private def runOnMethodParameter(diffGraph: DiffGraphBuilder, param: MethodParameterIn): Unit = - param.file.foreach { file => - val typeHint = param.typeFullName - val imports = fileToImports.getOrElse(file.name, List.empty) ++ param.method.typeDecl + private def runOnMethodReturn(diffGraph: DiffGraphBuilder, methodReturn: MethodReturn): Unit = + methodReturn.file.foreach { file => + val typeHint = methodReturn.typeFullName + val imports = + fileToImports.getOrElse(file.name, List.empty) ++ methodReturn.method.typeDecl .map(td => ImportScope(Option(pythonicTypeNameToImport(td.fullName)), Option(td.name)) ) .toList - imports + imports + .filter { x => // TODO: Handle * imports correctly - .filter(_.alias.exists { imported => + x.alias.exists { imported => typeHint.matches(Pattern.quote(imported) + "(\\..+)*") - }) - .foreach { - case ImportScope(Some(importedEntity), Some(importedAs)) => - setTypeHints(diffGraph, param, typeHint, importedAs, importedEntity) - case _ => } - } + } + .flatMap(_.entity) + .foreach { importedEntity => + setTypeHints(diffGraph, methodReturn, typeHint, typeHint, importedEntity) + } + } - private def pythonicTypeNameToImport(fullName: String): String = - fullName.replaceFirst("\\.py:", "").replaceAll(Pattern.quote(File.separator), ".") + private def runOnMethodParameter(diffGraph: DiffGraphBuilder, param: MethodParameterIn): Unit = + param.file.foreach { file => + val typeHint = param.typeFullName + val imports = fileToImports.getOrElse(file.name, List.empty) ++ param.method.typeDecl + .map(td => + ImportScope(Option(pythonicTypeNameToImport(td.fullName)), Option(td.name)) + ) + .toList + imports + // TODO: Handle * imports correctly + .filter(_.alias.exists { imported => + typeHint.matches(Pattern.quote(imported) + "(\\..+)*") + }) + .foreach { + case ImportScope(Some(importedEntity), Some(importedAs)) => + setTypeHints(diffGraph, param, typeHint, importedAs, importedEntity) + case _ => + } + } - private def setTypeHints( - diffGraph: BatchedUpdate.DiffGraphBuilder, - node: StoredNode, - typeHint: String, - alias: String, - importedEntity: String - ) = - val importFullPath = ImportStringHandling.combinedPath(importedEntity, typeHint) - val typeHintFullName = typeHint.replaceFirst(Pattern.quote(alias), importedEntity) - val typeFilePath = - typeHintFullName.replaceAll("\\.", Matcher.quoteReplacement(File.separator)) - val pythonicTypeFullName = importFullPath.split("\\.").lastOption match - case Some(typeName) => - typeFilePath.stripSuffix(s"${File.separator}$typeName").concat( - s".py:.$typeName" - ) - case None => typeHintFullName - cpg.typeDecl.fullName(s".*${Pattern.quote(pythonicTypeFullName)}").l match - case xs if xs.sizeIs == 1 => - diffGraph.setNodeProperty(node, PropertyNames.TYPE_FULL_NAME, xs.fullName.head) - case xs if xs.nonEmpty => - diffGraph.setNodeProperty( - node, - PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, - xs.fullName.toSeq - ) - case _ => - diffGraph.setNodeProperty(node, PropertyNames.TYPE_FULL_NAME, pythonicTypeFullName) - end setTypeHints + private def pythonicTypeNameToImport(fullName: String): String = + fullName.replaceFirst("\\.py:", "").replaceAll(Pattern.quote(File.separator), ".") + + private def setTypeHints( + diffGraph: BatchedUpdate.DiffGraphBuilder, + node: StoredNode, + typeHint: String, + alias: String, + importedEntity: String + ) = + val importFullPath = ImportStringHandling.combinedPath(importedEntity, typeHint) + val typeHintFullName = typeHint.replaceFirst(Pattern.quote(alias), importedEntity) + val typeFilePath = + typeHintFullName.replaceAll("\\.", Matcher.quoteReplacement(File.separator)) + val pythonicTypeFullName = importFullPath.split("\\.").lastOption match + case Some(typeName) => + typeFilePath.stripSuffix(s"${File.separator}$typeName").concat( + s".py:.$typeName" + ) + case None => typeHintFullName + cpg.typeDecl.fullName(s".*${Pattern.quote(pythonicTypeFullName)}").l match + case xs if xs.sizeIs == 1 => + diffGraph.setNodeProperty(node, PropertyNames.TYPE_FULL_NAME, xs.fullName.head) + case xs if xs.nonEmpty => + diffGraph.setNodeProperty( + node, + PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, + xs.fullName.toSeq + ) + case _ => + diffGraph.setNodeProperty(node, PropertyNames.TYPE_FULL_NAME, pythonicTypeFullName) + end setTypeHints end DynamicTypeHintFullNamePass diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/EdgeBuilder.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/EdgeBuilder.scala index 5d2a6b6e..80d306e4 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/EdgeBuilder.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/EdgeBuilder.scala @@ -26,81 +26,81 @@ import io.shiftleft.codepropertygraph.generated.{EdgeTypes, nodes} import overflowdb.BatchedUpdate.DiffGraphBuilder class EdgeBuilder(diffGraph: DiffGraphBuilder): - def astEdge(dstNode: nodes.NewNode, srcNode: nodes.NewNode, order: Int): Unit = - diffGraph.addEdge(srcNode, dstNode, EdgeTypes.AST) - addOrder(dstNode, order) + def astEdge(dstNode: nodes.NewNode, srcNode: nodes.NewNode, order: Int): Unit = + diffGraph.addEdge(srcNode, dstNode, EdgeTypes.AST) + addOrder(dstNode, order) - def argumentEdge(dstNode: nodes.NewNode, srcNode: nodes.NewNode, argIndex: Int): Unit = - diffGraph.addEdge(srcNode, dstNode, EdgeTypes.ARGUMENT) - addArgumentIndex(dstNode, argIndex) + def argumentEdge(dstNode: nodes.NewNode, srcNode: nodes.NewNode, argIndex: Int): Unit = + diffGraph.addEdge(srcNode, dstNode, EdgeTypes.ARGUMENT) + addArgumentIndex(dstNode, argIndex) - def argumentEdge(dstNode: nodes.NewNode, srcNode: nodes.NewNode, argName: String): Unit = - diffGraph.addEdge(srcNode, dstNode, EdgeTypes.ARGUMENT) - // We need to fill something according to the CPG spec. But the spec also says that argument - // index is ignored if argument name is provided. So we just put -1. - addArgumentIndex(dstNode, -1) - addArgumentName(dstNode, argName) + def argumentEdge(dstNode: nodes.NewNode, srcNode: nodes.NewNode, argName: String): Unit = + diffGraph.addEdge(srcNode, dstNode, EdgeTypes.ARGUMENT) + // We need to fill something according to the CPG spec. But the spec also says that argument + // index is ignored if argument name is provided. So we just put -1. + addArgumentIndex(dstNode, -1) + addArgumentName(dstNode, argName) - def receiverEdge(dstNode: nodes.NewNode, srcNode: nodes.NewNode): Unit = - diffGraph.addEdge(srcNode, dstNode, EdgeTypes.RECEIVER) + def receiverEdge(dstNode: nodes.NewNode, srcNode: nodes.NewNode): Unit = + diffGraph.addEdge(srcNode, dstNode, EdgeTypes.RECEIVER) - def conditionEdge(dstNode: nodes.NewNode, srcNode: nodes.NewNode): Unit = - diffGraph.addEdge(srcNode, dstNode, EdgeTypes.CONDITION) + def conditionEdge(dstNode: nodes.NewNode, srcNode: nodes.NewNode): Unit = + diffGraph.addEdge(srcNode, dstNode, EdgeTypes.CONDITION) - def refEdge(dstNode: nodes.NewNode, srcNode: nodes.NewNode): Unit = - diffGraph.addEdge(srcNode, dstNode, EdgeTypes.REF) + def refEdge(dstNode: nodes.NewNode, srcNode: nodes.NewNode): Unit = + diffGraph.addEdge(srcNode, dstNode, EdgeTypes.REF) - def captureEdge(dstNode: nodes.NewNode, srcNode: nodes.NewNode): Unit = - diffGraph.addEdge(srcNode, dstNode, EdgeTypes.CAPTURE) + def captureEdge(dstNode: nodes.NewNode, srcNode: nodes.NewNode): Unit = + diffGraph.addEdge(srcNode, dstNode, EdgeTypes.CAPTURE) - def bindsEdge(dstNode: nodes.NewNode, srcNode: nodes.NewNode): Unit = - diffGraph.addEdge(srcNode, dstNode, EdgeTypes.BINDS) + def bindsEdge(dstNode: nodes.NewNode, srcNode: nodes.NewNode): Unit = + diffGraph.addEdge(srcNode, dstNode, EdgeTypes.BINDS) - private def addOrder(node: nodes.NewNode, order: Int): Unit = node match - case n: NewTypeDecl => n.order = order - case n: NewBlock => n.order = order - case n: NewCall => n.order = order - case n: NewFieldIdentifier => n.order = order - case n: NewFile => n.order = order - case n: NewIdentifier => n.order = order - case n: NewLocal => n.order = order - case n: NewMethod => n.order = order - case n: NewMethodParameterIn => n.order = order - case n: NewMethodRef => n.order = order - case n: NewNamespaceBlock => n.order = order - case n: NewTypeRef => n.order = order - case n: NewUnknown => n.order = order - case n: NewModifier => n.order = order - case n: NewMethodReturn => n.order = order - case n: NewMember => n.order = order - case n: NewControlStructure => n.order = order - case n: NewLiteral => n.order = order - case n: NewReturn => n.order = order - case n: NewJumpTarget => n.order = order + private def addOrder(node: nodes.NewNode, order: Int): Unit = node match + case n: NewTypeDecl => n.order = order + case n: NewBlock => n.order = order + case n: NewCall => n.order = order + case n: NewFieldIdentifier => n.order = order + case n: NewFile => n.order = order + case n: NewIdentifier => n.order = order + case n: NewLocal => n.order = order + case n: NewMethod => n.order = order + case n: NewMethodParameterIn => n.order = order + case n: NewMethodRef => n.order = order + case n: NewNamespaceBlock => n.order = order + case n: NewTypeRef => n.order = order + case n: NewUnknown => n.order = order + case n: NewModifier => n.order = order + case n: NewMethodReturn => n.order = order + case n: NewMember => n.order = order + case n: NewControlStructure => n.order = order + case n: NewLiteral => n.order = order + case n: NewReturn => n.order = order + case n: NewJumpTarget => n.order = order - private def addArgumentIndex(node: nodes.NewNode, argIndex: Int): Unit = node match - case n: NewBlock => n.argumentIndex = argIndex - case n: NewCall => n.argumentIndex = argIndex - case n: NewFieldIdentifier => n.argumentIndex = argIndex - case n: NewIdentifier => n.argumentIndex = argIndex - case n: NewMethodRef => n.argumentIndex = argIndex - case n: NewTypeRef => n.argumentIndex = argIndex - case n: NewUnknown => n.argumentIndex = argIndex - case n: NewControlStructure => n.argumentIndex = argIndex - case n: NewLiteral => n.argumentIndex = argIndex - case n: NewReturn => n.argumentIndex = argIndex + private def addArgumentIndex(node: nodes.NewNode, argIndex: Int): Unit = node match + case n: NewBlock => n.argumentIndex = argIndex + case n: NewCall => n.argumentIndex = argIndex + case n: NewFieldIdentifier => n.argumentIndex = argIndex + case n: NewIdentifier => n.argumentIndex = argIndex + case n: NewMethodRef => n.argumentIndex = argIndex + case n: NewTypeRef => n.argumentIndex = argIndex + case n: NewUnknown => n.argumentIndex = argIndex + case n: NewControlStructure => n.argumentIndex = argIndex + case n: NewLiteral => n.argumentIndex = argIndex + case n: NewReturn => n.argumentIndex = argIndex - private def addArgumentName(node: nodes.NewNode, argName: String): Unit = - val someArgName = Some(argName) - node match - case n: NewBlock => n.argumentName = someArgName - case n: NewCall => n.argumentName = someArgName - case n: NewFieldIdentifier => n.argumentName = someArgName - case n: NewIdentifier => n.argumentName = someArgName - case n: NewMethodRef => n.argumentName = someArgName - case n: NewTypeRef => n.argumentName = someArgName - case n: NewUnknown => n.argumentName = someArgName - case n: NewControlStructure => n.argumentName = someArgName - case n: NewLiteral => n.argumentName = someArgName - case n: NewReturn => n.argumentName = someArgName + private def addArgumentName(node: nodes.NewNode, argName: String): Unit = + val someArgName = Some(argName) + node match + case n: NewBlock => n.argumentName = someArgName + case n: NewCall => n.argumentName = someArgName + case n: NewFieldIdentifier => n.argumentName = someArgName + case n: NewIdentifier => n.argumentName = someArgName + case n: NewMethodRef => n.argumentName = someArgName + case n: NewTypeRef => n.argumentName = someArgName + case n: NewUnknown => n.argumentName = someArgName + case n: NewControlStructure => n.argumentName = someArgName + case n: NewLiteral => n.argumentName = someArgName + case n: NewReturn => n.argumentName = someArgName end EdgeBuilder diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/ImportResolverPass.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/ImportResolverPass.scala index f0268a15..16345302 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/ImportResolverPass.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/ImportResolverPass.scala @@ -12,160 +12,160 @@ import java.util.regex.{Matcher, Pattern} class ImportResolverPass(cpg: Cpg) extends XImportResolverPass(cpg): - private lazy val root = cpg.metaData.root.headOption.getOrElse("").stripSuffix(JFile.separator) - - override protected def optionalResolveImport( - fileName: String, - importCall: Call, - importedEntity: String, - importedAs: String, - diffGraph: DiffGraphBuilder - ): Unit = - val (namespace, entityName) = if importedEntity.contains(".") then - val splitName = importedEntity.split('.').toSeq - val namespace = importedEntity.stripSuffix(s".${splitName.last}") - (relativizeNamespace(namespace, fileName), splitName.last) + private lazy val root = cpg.metaData.root.headOption.getOrElse("").stripSuffix(JFile.separator) + + override protected def optionalResolveImport( + fileName: String, + importCall: Call, + importedEntity: String, + importedAs: String, + diffGraph: DiffGraphBuilder + ): Unit = + val (namespace, entityName) = if importedEntity.contains(".") then + val splitName = importedEntity.split('.').toSeq + val namespace = importedEntity.stripSuffix(s".${splitName.last}") + (relativizeNamespace(namespace, fileName), splitName.last) + else + val currDir = BFile(root) / fileName match + case x if x.isDirectory => x + case x => x.parent + + val relCurrDir = currDir.pathAsString.stripPrefix(root).stripPrefix(JFile.separator) + + (relCurrDir, importedEntity) + + resolveEntities(namespace, entityName, importedAs).foreach(x => + resolvedImportToTag(x, importCall, diffGraph) + ) + end optionalResolveImport + + private def relativizeNamespace(path: String, fileName: String): String = + if path.startsWith(".") then + // TODO: pysrc2cpg does not link files to the correct namespace nodes + val sep = Matcher.quoteReplacement(JFile.separator) + // The below gives us the full path of the relative "." + val relativeNamespace = + if fileName.contains(JFile.separator) then + fileName.substring(0, fileName.lastIndexOf(JFile.separator)).replaceAll( + sep, + "." + ) + else "" + (if path.length > 1 then relativeNamespace + path.replaceAll(sep, ".") + else relativeNamespace).stripPrefix(".") + else path + + /** For an import - given by its module path and the name of the imported function or module - + * determine the possible callee names. + * + * @param path + * the module path. + * @param expEntity + * the name of the imported entity. This could be a function, module, or variable/field. + * @param alias + * how the imported entity is named. + * @return + * the possible callee names + */ + private def resolveEntities( + path: String, + expEntity: String, + alias: String + ): Set[ResolvedImport] = + + implicit class ResolvedNodeExt(val traversal: Seq[String]): + def toResolvedImport(cpg: Cpg): Seq[ResolvedImport] = + val resolvedEntities = + traversal.flatMap(x => + cpg.typeDecl.fullNameExact(x) ++ cpg.method.fullNameExact(x) + ).collect { + case x: Method => ResolvedMethod(x.fullName, alias) + case x: TypeDecl => ResolvedTypeDecl(x.fullName) + } + if resolvedEntities.isEmpty then + traversal.filterNot(_.contains("__init__.py")).map(x => UnknownImport(x)) else - val currDir = BFile(root) / fileName match - case x if x.isDirectory => x - case x => x.parent + resolvedEntities - val relCurrDir = currDir.pathAsString.stripPrefix(root).stripPrefix(JFile.separator) + implicit class CalleeAsInitExt(val name: String): + def asInit: String = if name.contains("__init__.py") then name + else name.replace(".py", s"${JFile.separator}__init__.py") - (relCurrDir, importedEntity) + def withInit: Seq[String] = Seq(name, name.asInit) - resolveEntities(namespace, entityName, importedAs).foreach(x => - resolvedImportToTag(x, importCall, diffGraph) + val pathSep = "." + val sep = Matcher.quoteReplacement(JFile.separator) + val isMaybeConstructor = + expEntity.split("\\.").lastOption.exists(s => s.nonEmpty && s.charAt(0).isUpper) + + lazy val membersMatchingImports: List[(TypeDecl, Member)] = cpg.typeDecl + .fullName(s".*${Pattern.quote(path)}.*") + .flatMap(t => + t.member.nameExact(expEntity).headOption match + case Some(member) => Option((t, member)) + case None => None ) - end optionalResolveImport - - private def relativizeNamespace(path: String, fileName: String): String = - if path.startsWith(".") then - // TODO: pysrc2cpg does not link files to the correct namespace nodes - val sep = Matcher.quoteReplacement(JFile.separator) - // The below gives us the full path of the relative "." - val relativeNamespace = - if fileName.contains(JFile.separator) then - fileName.substring(0, fileName.lastIndexOf(JFile.separator)).replaceAll( - sep, - "." - ) - else "" - (if path.length > 1 then relativeNamespace + path.replaceAll(sep, ".") - else relativeNamespace).stripPrefix(".") - else path - - /** For an import - given by its module path and the name of the imported function or module - - * determine the possible callee names. - * - * @param path - * the module path. - * @param expEntity - * the name of the imported entity. This could be a function, module, or variable/field. - * @param alias - * how the imported entity is named. - * @return - * the possible callee names - */ - private def resolveEntities( - path: String, - expEntity: String, - alias: String - ): Set[ResolvedImport] = - - implicit class ResolvedNodeExt(val traversal: Seq[String]): - def toResolvedImport(cpg: Cpg): Seq[ResolvedImport] = - val resolvedEntities = - traversal.flatMap(x => - cpg.typeDecl.fullNameExact(x) ++ cpg.method.fullNameExact(x) - ).collect { - case x: Method => ResolvedMethod(x.fullName, alias) - case x: TypeDecl => ResolvedTypeDecl(x.fullName) - } - if resolvedEntities.isEmpty then - traversal.filterNot(_.contains("__init__.py")).map(x => UnknownImport(x)) - else - resolvedEntities - - implicit class CalleeAsInitExt(val name: String): - def asInit: String = if name.contains("__init__.py") then name - else name.replace(".py", s"${JFile.separator}__init__.py") - - def withInit: Seq[String] = Seq(name, name.asInit) - - val pathSep = "." - val sep = Matcher.quoteReplacement(JFile.separator) - val isMaybeConstructor = - expEntity.split("\\.").lastOption.exists(s => s.nonEmpty && s.charAt(0).isUpper) - - lazy val membersMatchingImports: List[(TypeDecl, Member)] = cpg.typeDecl - .fullName(s".*${Pattern.quote(path)}.*") - .flatMap(t => - t.member.nameExact(expEntity).headOption match - case Some(member) => Option((t, member)) - case None => None - ) - .toList - - (path match - case "" if expEntity.contains(".") => - // Case 1: Qualified path: import foo.bar => (bar.py or bar/__init__.py) - val splitFunc = expEntity.split("\\.") - val name = splitFunc.tail.mkString(".") - s"${splitFunc(0)}.py:$pathSep$name".withInit.toResolvedImport(cpg) - case "" => - // Case 2: import of a module: import foo => (foo.py or foo/__init__.py) - s"$expEntity.py:".withInit.toResolvedImport(cpg) - case _ if membersMatchingImports.nonEmpty => - // Case 3: import of a variable: from api import db => (api.py or foo.__init__.py) @ identifier(db) - membersMatchingImports.map { - case (t, m) if t.method.nameExact(m.name).nonEmpty => - ResolvedMethod(t.method.nameExact(m.name).fullName.head, alias) - case (t, m) - if t.astSiblings.isMethod.fullNameExact( - t.fullName - ).ast.isTypeDecl.nameExact(m.name).nonEmpty => - ResolvedTypeDecl( - t.astSiblings.isMethod.fullNameExact(t.fullName).ast.isTypeDecl.nameExact( - m.name - ).fullName.head - ) - case (t, m) => ResolvedMember(t.fullName, m.name) - } - case _ => - // Case 4: Import from module using alias, e.g. import bar from foo as faz - val fileOrDir = BFile(codeRoot) / path - val pyFile = BFile(codeRoot) / s"$path.py" - fileOrDir match - case f if f.isDirectory && !pyFile.exists => - Seq( - s"${path.replaceAll("\\.", sep)}${java.io.File.separator}$expEntity.py:" - ).toResolvedImport(cpg) - case f if f.isDirectory && (f / s"$expEntity.py").exists => - Seq( - s"${(f / s"$expEntity.py").pathAsString.stripPrefix(codeRoot)}:" - ).toResolvedImport(cpg) - case _ => - s"${path.replaceAll("\\.", sep)}.py:$pathSep$expEntity".withInit.toResolvedImport( - cpg - ) - ).flatMap { - // If we import the constructor, we also import the type - case x: ResolvedMethod if isMaybeConstructor => + .toList + + (path match + case "" if expEntity.contains(".") => + // Case 1: Qualified path: import foo.bar => (bar.py or bar/__init__.py) + val splitFunc = expEntity.split("\\.") + val name = splitFunc.tail.mkString(".") + s"${splitFunc(0)}.py:$pathSep$name".withInit.toResolvedImport(cpg) + case "" => + // Case 2: import of a module: import foo => (foo.py or foo/__init__.py) + s"$expEntity.py:".withInit.toResolvedImport(cpg) + case _ if membersMatchingImports.nonEmpty => + // Case 3: import of a variable: from api import db => (api.py or foo.__init__.py) @ identifier(db) + membersMatchingImports.map { + case (t, m) if t.method.nameExact(m.name).nonEmpty => + ResolvedMethod(t.method.nameExact(m.name).fullName.head, alias) + case (t, m) + if t.astSiblings.isMethod.fullNameExact( + t.fullName + ).ast.isTypeDecl.nameExact(m.name).nonEmpty => + ResolvedTypeDecl( + t.astSiblings.isMethod.fullNameExact(t.fullName).ast.isTypeDecl.nameExact( + m.name + ).fullName.head + ) + case (t, m) => ResolvedMember(t.fullName, m.name) + } + case _ => + // Case 4: Import from module using alias, e.g. import bar from foo as faz + val fileOrDir = BFile(codeRoot) / path + val pyFile = BFile(codeRoot) / s"$path.py" + fileOrDir match + case f if f.isDirectory && !pyFile.exists => Seq( - ResolvedMethod(Seq(x.fullName, "__init__").mkString(pathSep), alias), - ResolvedTypeDecl(x.fullName) - ) - // If we import the type, we also import the constructor - case x: ResolvedTypeDecl if isMaybeConstructor => - Seq(x, ResolvedMethod(Seq(x.fullName, "__init__").mkString(pathSep), alias)) - // If we can determine the import is a constructor, then it is likely not a member - case x: UnknownImport if isMaybeConstructor => + s"${path.replaceAll("\\.", sep)}${java.io.File.separator}$expEntity.py:" + ).toResolvedImport(cpg) + case f if f.isDirectory && (f / s"$expEntity.py").exists => Seq( - UnknownMethod(Seq(x.path, "__init__").mkString(pathSep), alias), - UnknownTypeDecl(x.path) + s"${(f / s"$expEntity.py").pathAsString.stripPrefix(codeRoot)}:" + ).toResolvedImport(cpg) + case _ => + s"${path.replaceAll("\\.", sep)}.py:$pathSep$expEntity".withInit.toResolvedImport( + cpg ) - case x => Seq(x) - }.toSet - end resolveEntities + ).flatMap { + // If we import the constructor, we also import the type + case x: ResolvedMethod if isMaybeConstructor => + Seq( + ResolvedMethod(Seq(x.fullName, "__init__").mkString(pathSep), alias), + ResolvedTypeDecl(x.fullName) + ) + // If we import the type, we also import the constructor + case x: ResolvedTypeDecl if isMaybeConstructor => + Seq(x, ResolvedMethod(Seq(x.fullName, "__init__").mkString(pathSep), alias)) + // If we can determine the import is a constructor, then it is likely not a member + case x: UnknownImport if isMaybeConstructor => + Seq( + UnknownMethod(Seq(x.path, "__init__").mkString(pathSep), alias), + UnknownTypeDecl(x.path) + ) + case x => Seq(x) + }.toSet + end resolveEntities end ImportResolverPass diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/ImportsPass.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/ImportsPass.scala index c4f791e8..528f3b12 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/ImportsPass.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/ImportsPass.scala @@ -8,15 +8,15 @@ import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.Assignment class ImportsPass(cpg: Cpg) extends XImportsPass(cpg): - override protected val importCallName: String = "import" + override protected val importCallName: String = "import" - override protected def importCallToPart(x: Call): Iterator[(Call, Assignment)] = - x.inAssignment.map(y => (x, y)) + override protected def importCallToPart(x: Call): Iterator[(Call, Assignment)] = + x.inAssignment.map(y => (x, y)) - override def importedEntityFromCall(call: Call): String = - call.argument.code.l match - case List("", what) => what - case List(where, what) => s"$where.$what" - case List("", what, _) => what - case List(where, what, _) => s"$where.$what" - case _ => "" + override def importedEntityFromCall(call: Call): String = + call.argument.code.l match + case List("", what) => what + case List(where, what) => s"$where.$what" + case List("", what, _) => what + case List(where, what, _) => s"$where.$what" + case _ => "" diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Main.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Main.scala index f0291ba5..4f29e951 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Main.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Main.scala @@ -8,39 +8,39 @@ import scopt.OParser import java.nio.file.Paths private object Frontend: - val cmdLineParser: OParser[Unit, Py2CpgOnFileSystemConfig] = - val builder = OParser.builder[Py2CpgOnFileSystemConfig] - import builder.* - // Defaults for all command line options are specified in Py2CpgOFileSystemConfig - // because Scopt is a shit library. - OParser.sequence( - programName("pysrc2cpg"), - opt[String]("venvDir") - .text( - "Virtual environment directory. If not absolute it is interpreted relative to input-dir. Defaults to .venv." - ) - .action((dir, config) => config.withVenvDir(Paths.get(dir))), - opt[Boolean]("ignoreVenvDir") - .text("Specifies whether venv-dir is ignored. Default to true.") - .action(((value, config) => config.withIgnoreVenvDir(value))), - opt[Seq[String]]("ignore-paths") - .text( - "Ignores the specified path from analysis. If not absolute it is interpreted relative to input-dir." - ) - .action(((value, config) => config.withIgnorePaths(value.map(Paths.get(_))))), - opt[Seq[String]]("ignore-dir-names") - .text( - "Excludes all files where the relative path from input-dir contains at least one of names specified here." - ) - .action(((value, config) => config.withIgnoreDirNames(value))), - XTypeRecovery.parserOptions - ) - end cmdLineParser + val cmdLineParser: OParser[Unit, Py2CpgOnFileSystemConfig] = + val builder = OParser.builder[Py2CpgOnFileSystemConfig] + import builder.* + // Defaults for all command line options are specified in Py2CpgOFileSystemConfig + // because Scopt is a shit library. + OParser.sequence( + programName("pysrc2cpg"), + opt[String]("venvDir") + .text( + "Virtual environment directory. If not absolute it is interpreted relative to input-dir. Defaults to .venv." + ) + .action((dir, config) => config.withVenvDir(Paths.get(dir))), + opt[Boolean]("ignoreVenvDir") + .text("Specifies whether venv-dir is ignored. Default to true.") + .action(((value, config) => config.withIgnoreVenvDir(value))), + opt[Seq[String]]("ignore-paths") + .text( + "Ignores the specified path from analysis. If not absolute it is interpreted relative to input-dir." + ) + .action(((value, config) => config.withIgnorePaths(value.map(Paths.get(_))))), + opt[Seq[String]]("ignore-dir-names") + .text( + "Excludes all files where the relative path from input-dir contains at least one of names specified here." + ) + .action(((value, config) => config.withIgnoreDirNames(value))), + XTypeRecovery.parserOptions + ) + end cmdLineParser end Frontend object NewMain extends X2CpgMain(cmdLineParser, new Py2CpgOnFileSystem())(new Py2CpgOnFileSystemConfig()): - def run(config: Py2CpgOnFileSystemConfig, frontend: Py2CpgOnFileSystem): Unit = - frontend.run(config) + def run(config: Py2CpgOnFileSystemConfig, frontend: Py2CpgOnFileSystem): Unit = + frontend.run(config) - def getCmdLineParser = cmdLineParser + def getCmdLineParser = cmdLineParser diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/NodeBuilder.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/NodeBuilder.scala index 60d97517..4fbb5d5f 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/NodeBuilder.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/NodeBuilder.scala @@ -9,318 +9,318 @@ import overflowdb.BatchedUpdate.DiffGraphBuilder class NodeBuilder(diffGraph: DiffGraphBuilder): - private def addNodeToDiff[T <: nodes.NewNode](node: T): T = - diffGraph.addNode(node) - node - - def callNode( - code: String, - name: String, - dispatchType: String, - lineAndColumn: LineAndColumn - ): nodes.NewCall = - val callNode = nodes - .NewCall() - .code(code) - .name(name) - .methodFullName(if dispatchType == DispatchTypes.STATIC_DISPATCH then name - else Defines.DynamicCallUnknownFullName) - .dispatchType(dispatchType) - .typeFullName(Constants.ANY) - .lineNumber(lineAndColumn.line) - .columnNumber(lineAndColumn.column) - addNodeToDiff(callNode) - - def typeNode(name: String, fullName: String): nodes.NewType = - val typeNode = nodes - .NewType() - .name(name) - .fullName(fullName) - .typeDeclFullName(fullName) - addNodeToDiff(typeNode) - - def typeDeclNode( - name: String, - fullName: String, - fileName: String, - inheritsFromFullNames: collection.Seq[String], - lineAndColumn: LineAndColumn - ): nodes.NewTypeDecl = - val typeDeclNode = nodes - .NewTypeDecl() - .name(name) - .fullName(fullName) - .isExternal(false) - .filename(fileName) - .inheritsFromTypeFullName(inheritsFromFullNames) - .lineNumber(lineAndColumn.line) - .columnNumber(lineAndColumn.column) - addNodeToDiff(typeDeclNode) - - def typeRefNode( - code: String, - typeFullName: String, - lineAndColumn: LineAndColumn - ): nodes.NewTypeRef = - val typeRefNode = nodes - .NewTypeRef() - .code(code) - .typeFullName(typeFullName) - .lineNumber(lineAndColumn.line) - .columnNumber(lineAndColumn.column) - addNodeToDiff(typeRefNode) - - def memberNode(name: String): nodes.NewMember = - val memberNode = nodes - .NewMember() - .code(name) - .name(name) - .typeFullName(Constants.ANY) - addNodeToDiff(memberNode) - - def memberNode(name: String, lineAndColumn: LineAndColumn): nodes.NewMember = - memberNode(name) - .lineNumber(lineAndColumn.line) - .columnNumber(lineAndColumn.column) - - def memberNode(name: String, dynamicTypeHintFullName: String): nodes.NewMember = - memberNode(name).dynamicTypeHintFullName(dynamicTypeHintFullName :: Nil) - def memberNode( - name: String, - dynamicTypeHintFullName: String, - lineAndColumn: LineAndColumn - ): nodes.NewMember = - memberNode(name, lineAndColumn).dynamicTypeHintFullName(dynamicTypeHintFullName :: Nil) - - def bindingNode(): nodes.NewBinding = - val bindingNode = nodes - .NewBinding() - .name("") - .signature("") - - addNodeToDiff(bindingNode) - - def methodNode( - name: String, - fullName: String, - fileName: String, - lineAndColumn: LineAndColumn - ): nodes.NewMethod = - val methodNode = nodes - .NewMethod() - .name(name) - .fullName(fullName) - .filename(fileName) - .isExternal(false) - .lineNumber(lineAndColumn.line) - .lineNumberEnd(lineAndColumn.endLine) - .columnNumber(lineAndColumn.column) - .columnNumberEnd(lineAndColumn.endColumn) - addNodeToDiff(methodNode) - - def methodRefNode( - name: String, - fullName: String, - lineAndColumn: LineAndColumn - ): nodes.NewMethodRef = - val methodRefNode = nodes - .NewMethodRef() - .code(name) - .methodFullName(fullName) - .typeFullName(fullName) - .lineNumber(lineAndColumn.line) - .columnNumber(lineAndColumn.column) - addNodeToDiff(methodRefNode) - - def closureBindingNode( - closureBindingId: String, - closureOriginalName: String - ): nodes.NewClosureBinding = - val closureBindingNode = nodes - .NewClosureBinding() - .closureBindingId(Some(closureBindingId)) - .evaluationStrategy(EvaluationStrategies.BY_REFERENCE) - .closureOriginalName(Some(closureOriginalName)) - addNodeToDiff(closureBindingNode) - - def methodParameterNode( - name: String, - isVariadic: Boolean, - lineAndColumn: LineAndColumn, - index: Option[Int] = None, - typeHint: Option[ast.iexpr] = None - ): nodes.NewMethodParameterIn = - val methodParameterNode = nodes - .NewMethodParameterIn() - .name(name) - .code(name) - .evaluationStrategy(EvaluationStrategies.BY_SHARING) - .typeFullName(extractTypesFromHint(typeHint).getOrElse(Constants.ANY)) - .isVariadic(isVariadic) - .lineNumber(lineAndColumn.line) - .columnNumber(lineAndColumn.column) - index.foreach(idx => methodParameterNode.index(idx)) - addNodeToDiff(methodParameterNode) - - def extractTypesFromHint(typeHint: Option[ast.iexpr] = None): Option[String] = - typeHint match - case Some(hint) => - val nameSequence = hint match - case n: ast.Name => Option(n.id) - // TODO: Definitely a place for follow up handling of generics - currently only take the polymorphic type - // without type args. To see the type arguments, see ast.Subscript.slice - case attr: ast.Attribute => - extractTypesFromHint(Some(attr.value)).map { x => x + "." + attr.attr } - case n: ast.Subscript if n.value.isInstanceOf[ast.Name] => - Option(n.value.asInstanceOf[ast.Name].id) - case n: ast.Constant if n.value.isInstanceOf[ast.StringConstant] => - Option(n.value.asInstanceOf[ast.StringConstant].value) - case _ => None - nameSequence.map { typeName => - if allBuiltinClasses.contains(typeName) then s"$builtinPrefix$typeName" - else if typingClassesV3.contains(typeName) then s"$typingPrefix$typeName" - else typeName - } - case _ => None - - def methodReturnNode( - staticTypeHint: Option[String], - dynamicTypeHintFullName: Option[String], - lineAndColumn: LineAndColumn - ): nodes.NewMethodReturn = - val methodReturnNode = NodeBuilders - .newMethodReturnNode( - staticTypeHint.getOrElse(Constants.ANY), - dynamicTypeHintFullName, - Some(lineAndColumn.line), - Some(lineAndColumn.column) - ) - .evaluationStrategy(EvaluationStrategies.BY_SHARING) - - addNodeToDiff(methodReturnNode) - - def returnNode(code: String, lineAndColumn: LineAndColumn): nodes.NewReturn = - val returnNode = nodes - .NewReturn() - .code(code) - .lineNumber(lineAndColumn.line) - .columnNumber(lineAndColumn.column) - - addNodeToDiff(returnNode) - - def identifierNode(name: String, lineAndColumn: LineAndColumn): nodes.NewIdentifier = - val identifierNode = nodes - .NewIdentifier() - .code(name) - .name(name) - .typeFullName(Constants.ANY) - .lineNumber(lineAndColumn.line) - .columnNumber(lineAndColumn.column) - addNodeToDiff(identifierNode) - - def fieldIdentifierNode(name: String, lineAndColumn: LineAndColumn): nodes.NewFieldIdentifier = - val fieldIdentifierNode = nodes - .NewFieldIdentifier() - .code(name) - .canonicalName(name) - .lineNumber(lineAndColumn.line) - .columnNumber(lineAndColumn.column) - addNodeToDiff(fieldIdentifierNode) - - def numberLiteralNode(number: Int, lineAndColumn: LineAndColumn): nodes.NewLiteral = - numberLiteralNode(number.toString, lineAndColumn) - - def numberLiteralNode(number: String, lineAndColumn: LineAndColumn): nodes.NewLiteral = - val literalNode = nodes - .NewLiteral() - .code(number) - .typeFullName(Constants.ANY) - .lineNumber(lineAndColumn.line) - .columnNumber(lineAndColumn.column) - addNodeToDiff(literalNode) - - def stringLiteralNode(string: String, lineAndColumn: LineAndColumn): nodes.NewLiteral = - val literalNode = nodes - .NewLiteral() - .code(string) - .typeFullName(Constants.ANY) - .lineNumber(lineAndColumn.line) - .columnNumber(lineAndColumn.column) - addNodeToDiff(literalNode) - - def blockNode(code: String, lineAndColumn: LineAndColumn): nodes.NewBlock = - val blockNode = nodes - .NewBlock() - .code(code) - .typeFullName(Constants.ANY) - .lineNumber(lineAndColumn.line) - .columnNumber(lineAndColumn.column) - addNodeToDiff(blockNode) - - def controlStructureNode( - code: String, - controlStructureName: String, - lineAndColumn: LineAndColumn - ): nodes.NewControlStructure = - val controlStructureNode = nodes - .NewControlStructure() - .code(code) - .controlStructureType(controlStructureName) - .lineNumber(lineAndColumn.line) - .columnNumber(lineAndColumn.column) - addNodeToDiff(controlStructureNode) - - def localNode(name: String, closureBindingId: Option[String] = None): nodes.NewLocal = - val localNode = nodes - .NewLocal() - .code(name) - .name(name) - .closureBindingId(closureBindingId) - .typeFullName(Constants.ANY) - addNodeToDiff(localNode) - - def fileNode(fileName: String): nodes.NewFile = - val fileNode = nodes - .NewFile() - .name(fileName) - addNodeToDiff(fileNode) - - def namespaceBlockNode( - name: String, - fullName: String, - fileName: String - ): nodes.NewNamespaceBlock = - val namespaceBlockNode = nodes - .NewNamespaceBlock() - .name(name) - .fullName(fullName) - .filename(fileName) - addNodeToDiff(namespaceBlockNode) - - def modifierNode(modifierType: String): nodes.NewModifier = - val modifierNode = nodes - .NewModifier() - .modifierType(modifierType) - addNodeToDiff(modifierNode) - - def metaNode(language: String, version: String): nodes.NewMetaData = - val metaNode = nodes - .NewMetaData() - .language(language) - .version(version) - addNodeToDiff(metaNode) - - def unknownNode( - code: String, - parserTypeName: String, - lineAndColumn: LineAndColumn - ): nodes.NewUnknown = - val unknownNode = nodes - .NewUnknown() - .code(code) - .parserTypeName(parserTypeName) - .typeFullName(Constants.ANY) - .lineNumber(lineAndColumn.line) - .columnNumber(lineAndColumn.column) - addNodeToDiff(unknownNode) + private def addNodeToDiff[T <: nodes.NewNode](node: T): T = + diffGraph.addNode(node) + node + + def callNode( + code: String, + name: String, + dispatchType: String, + lineAndColumn: LineAndColumn + ): nodes.NewCall = + val callNode = nodes + .NewCall() + .code(code) + .name(name) + .methodFullName(if dispatchType == DispatchTypes.STATIC_DISPATCH then name + else Defines.DynamicCallUnknownFullName) + .dispatchType(dispatchType) + .typeFullName(Constants.ANY) + .lineNumber(lineAndColumn.line) + .columnNumber(lineAndColumn.column) + addNodeToDiff(callNode) + + def typeNode(name: String, fullName: String): nodes.NewType = + val typeNode = nodes + .NewType() + .name(name) + .fullName(fullName) + .typeDeclFullName(fullName) + addNodeToDiff(typeNode) + + def typeDeclNode( + name: String, + fullName: String, + fileName: String, + inheritsFromFullNames: collection.Seq[String], + lineAndColumn: LineAndColumn + ): nodes.NewTypeDecl = + val typeDeclNode = nodes + .NewTypeDecl() + .name(name) + .fullName(fullName) + .isExternal(false) + .filename(fileName) + .inheritsFromTypeFullName(inheritsFromFullNames) + .lineNumber(lineAndColumn.line) + .columnNumber(lineAndColumn.column) + addNodeToDiff(typeDeclNode) + + def typeRefNode( + code: String, + typeFullName: String, + lineAndColumn: LineAndColumn + ): nodes.NewTypeRef = + val typeRefNode = nodes + .NewTypeRef() + .code(code) + .typeFullName(typeFullName) + .lineNumber(lineAndColumn.line) + .columnNumber(lineAndColumn.column) + addNodeToDiff(typeRefNode) + + def memberNode(name: String): nodes.NewMember = + val memberNode = nodes + .NewMember() + .code(name) + .name(name) + .typeFullName(Constants.ANY) + addNodeToDiff(memberNode) + + def memberNode(name: String, lineAndColumn: LineAndColumn): nodes.NewMember = + memberNode(name) + .lineNumber(lineAndColumn.line) + .columnNumber(lineAndColumn.column) + + def memberNode(name: String, dynamicTypeHintFullName: String): nodes.NewMember = + memberNode(name).dynamicTypeHintFullName(dynamicTypeHintFullName :: Nil) + def memberNode( + name: String, + dynamicTypeHintFullName: String, + lineAndColumn: LineAndColumn + ): nodes.NewMember = + memberNode(name, lineAndColumn).dynamicTypeHintFullName(dynamicTypeHintFullName :: Nil) + + def bindingNode(): nodes.NewBinding = + val bindingNode = nodes + .NewBinding() + .name("") + .signature("") + + addNodeToDiff(bindingNode) + + def methodNode( + name: String, + fullName: String, + fileName: String, + lineAndColumn: LineAndColumn + ): nodes.NewMethod = + val methodNode = nodes + .NewMethod() + .name(name) + .fullName(fullName) + .filename(fileName) + .isExternal(false) + .lineNumber(lineAndColumn.line) + .lineNumberEnd(lineAndColumn.endLine) + .columnNumber(lineAndColumn.column) + .columnNumberEnd(lineAndColumn.endColumn) + addNodeToDiff(methodNode) + + def methodRefNode( + name: String, + fullName: String, + lineAndColumn: LineAndColumn + ): nodes.NewMethodRef = + val methodRefNode = nodes + .NewMethodRef() + .code(name) + .methodFullName(fullName) + .typeFullName(fullName) + .lineNumber(lineAndColumn.line) + .columnNumber(lineAndColumn.column) + addNodeToDiff(methodRefNode) + + def closureBindingNode( + closureBindingId: String, + closureOriginalName: String + ): nodes.NewClosureBinding = + val closureBindingNode = nodes + .NewClosureBinding() + .closureBindingId(Some(closureBindingId)) + .evaluationStrategy(EvaluationStrategies.BY_REFERENCE) + .closureOriginalName(Some(closureOriginalName)) + addNodeToDiff(closureBindingNode) + + def methodParameterNode( + name: String, + isVariadic: Boolean, + lineAndColumn: LineAndColumn, + index: Option[Int] = None, + typeHint: Option[ast.iexpr] = None + ): nodes.NewMethodParameterIn = + val methodParameterNode = nodes + .NewMethodParameterIn() + .name(name) + .code(name) + .evaluationStrategy(EvaluationStrategies.BY_SHARING) + .typeFullName(extractTypesFromHint(typeHint).getOrElse(Constants.ANY)) + .isVariadic(isVariadic) + .lineNumber(lineAndColumn.line) + .columnNumber(lineAndColumn.column) + index.foreach(idx => methodParameterNode.index(idx)) + addNodeToDiff(methodParameterNode) + + def extractTypesFromHint(typeHint: Option[ast.iexpr] = None): Option[String] = + typeHint match + case Some(hint) => + val nameSequence = hint match + case n: ast.Name => Option(n.id) + // TODO: Definitely a place for follow up handling of generics - currently only take the polymorphic type + // without type args. To see the type arguments, see ast.Subscript.slice + case attr: ast.Attribute => + extractTypesFromHint(Some(attr.value)).map { x => x + "." + attr.attr } + case n: ast.Subscript if n.value.isInstanceOf[ast.Name] => + Option(n.value.asInstanceOf[ast.Name].id) + case n: ast.Constant if n.value.isInstanceOf[ast.StringConstant] => + Option(n.value.asInstanceOf[ast.StringConstant].value) + case _ => None + nameSequence.map { typeName => + if allBuiltinClasses.contains(typeName) then s"$builtinPrefix$typeName" + else if typingClassesV3.contains(typeName) then s"$typingPrefix$typeName" + else typeName + } + case _ => None + + def methodReturnNode( + staticTypeHint: Option[String], + dynamicTypeHintFullName: Option[String], + lineAndColumn: LineAndColumn + ): nodes.NewMethodReturn = + val methodReturnNode = NodeBuilders + .newMethodReturnNode( + staticTypeHint.getOrElse(Constants.ANY), + dynamicTypeHintFullName, + Some(lineAndColumn.line), + Some(lineAndColumn.column) + ) + .evaluationStrategy(EvaluationStrategies.BY_SHARING) + + addNodeToDiff(methodReturnNode) + + def returnNode(code: String, lineAndColumn: LineAndColumn): nodes.NewReturn = + val returnNode = nodes + .NewReturn() + .code(code) + .lineNumber(lineAndColumn.line) + .columnNumber(lineAndColumn.column) + + addNodeToDiff(returnNode) + + def identifierNode(name: String, lineAndColumn: LineAndColumn): nodes.NewIdentifier = + val identifierNode = nodes + .NewIdentifier() + .code(name) + .name(name) + .typeFullName(Constants.ANY) + .lineNumber(lineAndColumn.line) + .columnNumber(lineAndColumn.column) + addNodeToDiff(identifierNode) + + def fieldIdentifierNode(name: String, lineAndColumn: LineAndColumn): nodes.NewFieldIdentifier = + val fieldIdentifierNode = nodes + .NewFieldIdentifier() + .code(name) + .canonicalName(name) + .lineNumber(lineAndColumn.line) + .columnNumber(lineAndColumn.column) + addNodeToDiff(fieldIdentifierNode) + + def numberLiteralNode(number: Int, lineAndColumn: LineAndColumn): nodes.NewLiteral = + numberLiteralNode(number.toString, lineAndColumn) + + def numberLiteralNode(number: String, lineAndColumn: LineAndColumn): nodes.NewLiteral = + val literalNode = nodes + .NewLiteral() + .code(number) + .typeFullName(Constants.ANY) + .lineNumber(lineAndColumn.line) + .columnNumber(lineAndColumn.column) + addNodeToDiff(literalNode) + + def stringLiteralNode(string: String, lineAndColumn: LineAndColumn): nodes.NewLiteral = + val literalNode = nodes + .NewLiteral() + .code(string) + .typeFullName(Constants.ANY) + .lineNumber(lineAndColumn.line) + .columnNumber(lineAndColumn.column) + addNodeToDiff(literalNode) + + def blockNode(code: String, lineAndColumn: LineAndColumn): nodes.NewBlock = + val blockNode = nodes + .NewBlock() + .code(code) + .typeFullName(Constants.ANY) + .lineNumber(lineAndColumn.line) + .columnNumber(lineAndColumn.column) + addNodeToDiff(blockNode) + + def controlStructureNode( + code: String, + controlStructureName: String, + lineAndColumn: LineAndColumn + ): nodes.NewControlStructure = + val controlStructureNode = nodes + .NewControlStructure() + .code(code) + .controlStructureType(controlStructureName) + .lineNumber(lineAndColumn.line) + .columnNumber(lineAndColumn.column) + addNodeToDiff(controlStructureNode) + + def localNode(name: String, closureBindingId: Option[String] = None): nodes.NewLocal = + val localNode = nodes + .NewLocal() + .code(name) + .name(name) + .closureBindingId(closureBindingId) + .typeFullName(Constants.ANY) + addNodeToDiff(localNode) + + def fileNode(fileName: String): nodes.NewFile = + val fileNode = nodes + .NewFile() + .name(fileName) + addNodeToDiff(fileNode) + + def namespaceBlockNode( + name: String, + fullName: String, + fileName: String + ): nodes.NewNamespaceBlock = + val namespaceBlockNode = nodes + .NewNamespaceBlock() + .name(name) + .fullName(fullName) + .filename(fileName) + addNodeToDiff(namespaceBlockNode) + + def modifierNode(modifierType: String): nodes.NewModifier = + val modifierNode = nodes + .NewModifier() + .modifierType(modifierType) + addNodeToDiff(modifierNode) + + def metaNode(language: String, version: String): nodes.NewMetaData = + val metaNode = nodes + .NewMetaData() + .language(language) + .version(version) + addNodeToDiff(metaNode) + + def unknownNode( + code: String, + parserTypeName: String, + lineAndColumn: LineAndColumn + ): nodes.NewUnknown = + val unknownNode = nodes + .NewUnknown() + .code(code) + .parserTypeName(parserTypeName) + .typeFullName(Constants.ANY) + .lineNumber(lineAndColumn.line) + .columnNumber(lineAndColumn.column) + addNodeToDiff(unknownNode) end NodeBuilder diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/NodeToCode.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/NodeToCode.scala index e37bbd4a..1d3fc458 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/NodeToCode.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/NodeToCode.scala @@ -3,5 +3,5 @@ package io.appthreat.pysrc2cpg import io.appthreat.pythonparser.ast class NodeToCode(content: String): - def getCode(node: ast.iattributes): String = - content.substring(node.input_offset, node.end_input_offset) + def getCode(node: ast.iattributes): String = + content.substring(node.input_offset, node.end_input_offset) diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Py2Cpg.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Py2Cpg.scala index aea27100..32eda8d1 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Py2Cpg.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Py2Cpg.scala @@ -7,8 +7,8 @@ import overflowdb.BatchedUpdate import overflowdb.BatchedUpdate.DiffGraphBuilder object Py2Cpg: - case class InputPair(content: String, relFileName: String) - type InputProvider = () => InputPair + case class InputPair(content: String, relFileName: String) + type InputProvider = () => InputPair /** Entry point for general cpg generation from python code. * @@ -31,32 +31,32 @@ class Py2Cpg( requirementsTxt: String = "requirements.txt", schemaValidationMode: ValidationMode ): - private val diffGraph = new DiffGraphBuilder() - private val nodeBuilder = new NodeBuilder(diffGraph) - private val edgeBuilder = new EdgeBuilder(diffGraph) + private val diffGraph = new DiffGraphBuilder() + private val nodeBuilder = new NodeBuilder(diffGraph) + private val edgeBuilder = new EdgeBuilder(diffGraph) - def buildCpg(): Unit = - nodeBuilder.metaNode(Languages.PYTHONSRC, version = "").root( - inputPath + java.io.File.separator + def buildCpg(): Unit = + nodeBuilder.metaNode(Languages.PYTHONSRC, version = "").root( + inputPath + java.io.File.separator + ) + val globalNamespaceBlock = + nodeBuilder.namespaceBlockNode( + Constants.GLOBAL_NAMESPACE, + Constants.GLOBAL_NAMESPACE, + "N/A" ) - val globalNamespaceBlock = - nodeBuilder.namespaceBlockNode( - Constants.GLOBAL_NAMESPACE, - Constants.GLOBAL_NAMESPACE, - "N/A" - ) - nodeBuilder.typeNode(Constants.ANY, Constants.ANY) - val anyTypeDecl = nodeBuilder.typeDeclNode( - Constants.ANY, - Constants.ANY, - "N/A", - Nil, - LineAndColumn(1, 1, 1, 1) - ) - edgeBuilder.astEdge(anyTypeDecl, globalNamespaceBlock, 0) - BatchedUpdate.applyDiff(outputCpg.graph, diffGraph) - new CodeToCpg(outputCpg, inputProviders, schemaValidationMode).createAndApply() - new ConfigFileCreationPass(outputCpg, requirementsTxt).createAndApply() - new DependenciesFromRequirementsTxtPass(outputCpg).createAndApply() - end buildCpg + nodeBuilder.typeNode(Constants.ANY, Constants.ANY) + val anyTypeDecl = nodeBuilder.typeDeclNode( + Constants.ANY, + Constants.ANY, + "N/A", + Nil, + LineAndColumn(1, 1, 1, 1) + ) + edgeBuilder.astEdge(anyTypeDecl, globalNamespaceBlock, 0) + BatchedUpdate.applyDiff(outputCpg.graph, diffGraph) + new CodeToCpg(outputCpg, inputProviders, schemaValidationMode).createAndApply() + new ConfigFileCreationPass(outputCpg, requirementsTxt).createAndApply() + new DependenciesFromRequirementsTxtPass(outputCpg).createAndApply() + end buildCpg end Py2Cpg diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Py2CpgOnFileSystem.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Py2CpgOnFileSystem.scala index a255b883..a5238cbe 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Py2CpgOnFileSystem.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/Py2CpgOnFileSystem.scala @@ -18,87 +18,87 @@ case class Py2CpgOnFileSystemConfig( requirementsTxt: String = "requirements.txt" ) extends X2CpgConfig[Py2CpgOnFileSystemConfig] with TypeRecoveryParserConfig[Py2CpgOnFileSystemConfig]: - def withVenvDir(venvDir: Path): Py2CpgOnFileSystemConfig = - copy(venvDir = venvDir).withInheritedFields(this) + def withVenvDir(venvDir: Path): Py2CpgOnFileSystemConfig = + copy(venvDir = venvDir).withInheritedFields(this) - def withIgnoreVenvDir(value: Boolean): Py2CpgOnFileSystemConfig = - copy(ignoreVenvDir = value).withInheritedFields(this) + def withIgnoreVenvDir(value: Boolean): Py2CpgOnFileSystemConfig = + copy(ignoreVenvDir = value).withInheritedFields(this) - def withIgnorePaths(value: Seq[Path]): Py2CpgOnFileSystemConfig = - copy(ignorePaths = value).withInheritedFields(this) + def withIgnorePaths(value: Seq[Path]): Py2CpgOnFileSystemConfig = + copy(ignorePaths = value).withInheritedFields(this) - def withIgnoreDirNames(value: Seq[String]): Py2CpgOnFileSystemConfig = - copy(ignoreDirNames = value).withInheritedFields(this) + def withIgnoreDirNames(value: Seq[String]): Py2CpgOnFileSystemConfig = + copy(ignoreDirNames = value).withInheritedFields(this) - def withRequirementsTxt(text: String): Py2CpgOnFileSystemConfig = - copy(requirementsTxt = text).withInheritedFields(this) + def withRequirementsTxt(text: String): Py2CpgOnFileSystemConfig = + copy(requirementsTxt = text).withInheritedFields(this) end Py2CpgOnFileSystemConfig class Py2CpgOnFileSystem extends X2CpgFrontend[Py2CpgOnFileSystemConfig]: - private val logger = LoggerFactory.getLogger(getClass) - - /** Entry point for files system based cpg generation from python code. - * @param config - * Configuration for cpg generation. - */ - override def createCpg(config: Py2CpgOnFileSystemConfig): Try[Cpg] = - logConfiguration(config) - - X2Cpg.withNewEmptyCpg(config.outputPath, config) { (cpg, _) => - val venvIgnorePath = - if config.ignoreVenvDir then - config.venvDir :: Nil - else - Nil - val inputPath = Path.of(config.inputPath) - val ignoreDirNamesSet = config.ignoreDirNames.toSet - val absoluteIgnorePaths = (config.ignorePaths ++ venvIgnorePath).map { path => - inputPath.resolve(path) - } - - val inputFiles = SourceFiles - .determine(config.inputPath, Set(".py"), config) - .map(x => Path.of(x)) - .filter { file => filterIgnoreDirNames(file, inputPath, ignoreDirNamesSet) } - .filter { file => - !absoluteIgnorePaths.exists(ignorePath => file.startsWith(ignorePath)) - } - - val inputProviders = inputFiles.map { inputFile => () => - val content = IOUtils.readLinesInFile(inputFile).mkString("\n") - Py2Cpg.InputPair(content, inputPath.relativize(inputFile).toString) - } - val py2Cpg = new Py2Cpg( - inputProviders, - cpg, - config.inputPath, - config.requirementsTxt, - config.schemaValidation - ) - py2Cpg.buildCpg() - } - end createCpg - - private def filterIgnoreDirNames( - file: Path, - inputPath: Path, - ignoreDirNamesSet: Set[String] - ): Boolean = - var parts = inputPath.relativize(file).iterator().asScala.toList - - if !Files.isDirectory(file) then - // we're only interested in the directories - drop the file part - parts = parts.dropRight(1) - - val aPartIsInIgnoreSet = parts.exists(part => ignoreDirNamesSet.contains(part.toString)) - !aPartIsInIgnoreSet - - private def logConfiguration(config: Py2CpgOnFileSystemConfig): Unit = - logger.debug(s"Output file: ${config.outputPath}") - logger.debug(s"Input directory: ${config.inputPath}") - logger.debug(s"Venv directory: ${config.venvDir}") - logger.debug(s"IgnoreVenvDir: ${config.ignoreVenvDir}") - logger.debug(s"IgnorePaths: ${config.ignorePaths.mkString(", ")}") - logger.debug(s"IgnoreDirNames: ${config.ignoreDirNames.mkString(", ")}") - logger.debug(s"No dummy types: ${config.disableDummyTypes}") + private val logger = LoggerFactory.getLogger(getClass) + + /** Entry point for files system based cpg generation from python code. + * @param config + * Configuration for cpg generation. + */ + override def createCpg(config: Py2CpgOnFileSystemConfig): Try[Cpg] = + logConfiguration(config) + + X2Cpg.withNewEmptyCpg(config.outputPath, config) { (cpg, _) => + val venvIgnorePath = + if config.ignoreVenvDir then + config.venvDir :: Nil + else + Nil + val inputPath = Path.of(config.inputPath) + val ignoreDirNamesSet = config.ignoreDirNames.toSet + val absoluteIgnorePaths = (config.ignorePaths ++ venvIgnorePath).map { path => + inputPath.resolve(path) + } + + val inputFiles = SourceFiles + .determine(config.inputPath, Set(".py"), config) + .map(x => Path.of(x)) + .filter { file => filterIgnoreDirNames(file, inputPath, ignoreDirNamesSet) } + .filter { file => + !absoluteIgnorePaths.exists(ignorePath => file.startsWith(ignorePath)) + } + + val inputProviders = inputFiles.map { inputFile => () => + val content = IOUtils.readLinesInFile(inputFile).mkString("\n") + Py2Cpg.InputPair(content, inputPath.relativize(inputFile).toString) + } + val py2Cpg = new Py2Cpg( + inputProviders, + cpg, + config.inputPath, + config.requirementsTxt, + config.schemaValidation + ) + py2Cpg.buildCpg() + } + end createCpg + + private def filterIgnoreDirNames( + file: Path, + inputPath: Path, + ignoreDirNamesSet: Set[String] + ): Boolean = + var parts = inputPath.relativize(file).iterator().asScala.toList + + if !Files.isDirectory(file) then + // we're only interested in the directories - drop the file part + parts = parts.dropRight(1) + + val aPartIsInIgnoreSet = parts.exists(part => ignoreDirNamesSet.contains(part.toString)) + !aPartIsInIgnoreSet + + private def logConfiguration(config: Py2CpgOnFileSystemConfig): Unit = + logger.debug(s"Output file: ${config.outputPath}") + logger.debug(s"Input directory: ${config.inputPath}") + logger.debug(s"Venv directory: ${config.venvDir}") + logger.debug(s"IgnoreVenvDir: ${config.ignoreVenvDir}") + logger.debug(s"IgnorePaths: ${config.ignorePaths.mkString(", ")}") + logger.debug(s"IgnoreDirNames: ${config.ignoreDirNames.mkString(", ")}") + logger.debug(s"No dummy types: ${config.disableDummyTypes}") end Py2CpgOnFileSystem diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonAstVisitor.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonAstVisitor.scala index 8e62fb05..99bc9ddd 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonAstVisitor.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonAstVisitor.scala @@ -18,8 +18,8 @@ import overflowdb.BatchedUpdate.DiffGraphBuilder import scala.collection.mutable object MethodParameters: - def empty(): MethodParameters = - new MethodParameters(0, Nil) + def empty(): MethodParameters = + new MethodParameters(0, Nil) case class MethodParameters( posStartIndex: Int, positionalParams: Iterable[nodes.NewMethodParameterIn] @@ -36,2398 +36,2399 @@ class PythonAstVisitor( version: PythonVersion )(implicit withSchemaValidation: ValidationMode) extends PythonAstVisitorHelpers: - private val diffGraph = new DiffGraphBuilder() - protected val nodeBuilder = new NodeBuilder(diffGraph) - protected val edgeBuilder = new EdgeBuilder(diffGraph) + private val diffGraph = new DiffGraphBuilder() + protected val nodeBuilder = new NodeBuilder(diffGraph) + protected val edgeBuilder = new EdgeBuilder(diffGraph) - protected val contextStack = new ContextStack() + protected val contextStack = new ContextStack() - private var memOpMap: AstNodeToMemoryOperationMap = scala.compiletime.uninitialized + private var memOpMap: AstNodeToMemoryOperationMap = scala.compiletime.uninitialized - private val members = mutable.Map.empty[NewTypeDecl, List[String]] + private val members = mutable.Map.empty[NewTypeDecl, List[String]] - // As key only ast.FunctionDef and ast.AsyncFunctionDef are used but there - // is no more specific type than ast.istmt. - private val functionDefToMethod = mutable.Map.empty[ast.istmt, nodes.NewMethod] + // As key only ast.FunctionDef and ast.AsyncFunctionDef are used but there + // is no more specific type than ast.istmt. + private val functionDefToMethod = mutable.Map.empty[ast.istmt, nodes.NewMethod] - def getDiffGraph: DiffGraphBuilder = - diffGraph + def getDiffGraph: DiffGraphBuilder = + diffGraph - private def createIdentifierLinks(): Unit = - contextStack.createIdentifierLinks( - nodeBuilder.localNode, - nodeBuilder.closureBindingNode, - edgeBuilder.astEdge, - edgeBuilder.refEdge, - edgeBuilder.captureEdge - ) + private def createIdentifierLinks(): Unit = + contextStack.createIdentifierLinks( + nodeBuilder.localNode, + nodeBuilder.closureBindingNode, + edgeBuilder.astEdge, + edgeBuilder.refEdge, + edgeBuilder.captureEdge + ) - def convert(astNode: ast.iast): NewNode = - astNode match - case module: ast.Module => convert(module) - - def convert(mod: ast.imod): NewNode = - mod match - case node: ast.Module => convert(node) - - // Entry method for the visitor. - def convert(module: ast.Module): NewNode = - val memOpCalculator = new MemoryOperationCalculator() - module.accept(memOpCalculator) - memOpMap = memOpCalculator.astNodeToMemOp - - val fileNode = nodeBuilder.fileNode(relFileName) - val namespaceBlockNode = - nodeBuilder.namespaceBlockNode( - Constants.GLOBAL_NAMESPACE, - relFileName + ":" + Constants.GLOBAL_NAMESPACE, - relFileName - ) - edgeBuilder.astEdge(namespaceBlockNode, fileNode, 1) - contextStack.setFileNamespaceBlock(namespaceBlockNode) - - val methodFullName = calculateFullNameFromContext("") - - val firstLineAndCol = module.stmts.headOption.map(lineAndColOf) - val lastLineAndCol = module.stmts.lastOption.map(lineAndColOf) - val line = firstLineAndCol.map(_.line).getOrElse(1) - val column = firstLineAndCol.map(_.column).getOrElse(1) - val endLine = lastLineAndCol.map(_.endLine).getOrElse(1) - val endColumn = lastLineAndCol.map(_.endColumn).getOrElse(1) - - val moduleMethodNode = - createMethod( - "", - methodFullName, - Some(""), - parameterProvider = () => MethodParameters.empty(), - bodyProvider = () => - createBuiltinIdentifiers(memOpCalculator.names) ++ module.stmts.map(convert), - returns = None, - isAsync = false, - methodRefNode = None, - returnTypeHint = None, - LineAndColumn(line, column, endLine, endColumn) - ) + def convert(astNode: ast.iast): NewNode = + astNode match + case module: ast.Module => convert(module) - createIdentifierLinks() - - moduleMethodNode - end convert - - // Create assignments of type references to all builtin identifiers if the identifier appears - // at least once in the module. We filter in order to not generate the complete list of - // assignment in each module. - // That logic still generates assignments in cases where we would not need to, but figuring - // that out would mean we need to wait until all identifiers are linked which is than too - // late to create new identifiers and still use the same link mechanism. We would need - // to rearrange quite some code to accomplish that. So we leave that as an optional TODO. - // Note that namesUsedInModule is only calculated from ast.Name nodes! So e.g. new names - // artificially created during lowering are not in that collection which is fine for now. - private def createBuiltinIdentifiers(namesUsedInModule: collection.Set[String]) - : Iterable[nodes.NewNode] = - val result = mutable.ArrayBuffer.empty[nodes.NewNode] - val lineAndColumn = LineAndColumn(1, 1, 1, 1) - - val builtinFunctions = mutable.ArrayBuffer.empty[String] - val builtinClasses = mutable.ArrayBuffer.empty[String] - - if version == PythonV3 || version == PythonV2AndV3 then - builtinFunctions.appendAll(PythonAstVisitor.builtinFunctionsV3) - builtinClasses.appendAll(PythonAstVisitor.builtinClassesV3) - if version == PythonV2 || version == PythonV2AndV3 then - builtinFunctions.appendAll(PythonAstVisitor.builtinFunctionsV2) - builtinClasses.appendAll(PythonAstVisitor.builtinClassesV2) - - builtinFunctions.distinct.foreach { builtinObjectName => - if namesUsedInModule.contains(builtinObjectName) then - val assignmentNode = createAssignment( - createIdentifierNode(builtinObjectName, Store, lineAndColumn), - nodeBuilder - .typeRefNode( - "__builtins__." + builtinObjectName, - builtinPrefix + builtinObjectName, - lineAndColumn - ), - lineAndColumn - ) + def convert(mod: ast.imod): NewNode = + mod match + case node: ast.Module => convert(node) - result.append(assignmentNode) - } - - builtinClasses.distinct.foreach { builtinObjectName => - if namesUsedInModule.contains(builtinObjectName) then - val assignmentNode = createAssignment( - createIdentifierNode(builtinObjectName, Store, lineAndColumn), - nodeBuilder.typeRefNode( - "__builtins__." + builtinObjectName, - builtinPrefix + builtinObjectName + metaClassSuffix, - lineAndColumn - ), - lineAndColumn - ) - - result.append(assignmentNode) - } + // Entry method for the visitor. + def convert(module: ast.Module): NewNode = + val memOpCalculator = new MemoryOperationCalculator() + module.accept(memOpCalculator) + memOpMap = memOpCalculator.astNodeToMemOp - result - end createBuiltinIdentifiers - - private def unhandled(node: ast.iast & ast.iattributes): NewNode = - val unhandledAsUnknown = true - if unhandledAsUnknown then - nodeBuilder.unknownNode(node.toString, node.getClass.getName, lineAndColOf(node)) - else - throw new NotImplementedError() - - def convert(stmt: ast.istmt): NewNode = - stmt match - case node: ast.FunctionDef => convert(node) - case node: ast.AsyncFunctionDef => convert(node) - case node: ast.ClassDef => convert(node) - case node: ast.Return => convert(node) - case node: ast.Delete => convert(node) - case node: ast.Assign => convert(node) - case node: ast.AnnAssign => convert(node) - case node: ast.AugAssign => convert(node) - case node: ast.For => convert(node) - case node: ast.AsyncFor => convert(node) - case node: ast.While => convert(node) - case node: ast.If => convert(node) - case node: ast.With => convert(node) - case node: ast.AsyncWith => convert(node) - case node: ast.Match => convert(node) - case node: ast.Raise => convert(node) - case node: ast.Try => convert(node) - case node: ast.Assert => convert(node) - case node: ast.Import => convert(node) - case node: ast.ImportFrom => convert(node) - case node: ast.Global => convert(node) - case node: ast.Nonlocal => convert(node) - case node: ast.Expr => convert(node) - case node: ast.Pass => convert(node) - case node: ast.Break => convert(node) - case node: ast.Continue => convert(node) - case node: ast.RaiseP2 => unhandled(node) - case node: ast.ErrorStatement => convert(node) - - def convert(functionDef: ast.FunctionDef): NewNode = - val methodIdentifierNode = - createIdentifierNode(functionDef.name, Store, lineAndColOf(functionDef)) - val (methodNode, methodRefNode) = createMethodAndMethodRef( - functionDef.name, - Some(functionDef.name), - createParameterProcessingFunction( - functionDef.args, - isStaticMethod(functionDef.decorator_list) - ), - () => functionDef.body.map(convert), - functionDef.returns, - isAsync = false, - lineAndColOf(functionDef) - ) - functionDefToMethod.put(functionDef, methodNode) - - val wrappedMethodRefNode = - wrapMethodRefWithDecorators(methodRefNode, functionDef.decorator_list) - - createAssignment(methodIdentifierNode, wrappedMethodRefNode, lineAndColOf(functionDef)) - end convert - - /* - * For a decorated function like: - * @f1(arg) - * @f2 - * def func(): pass - * - * The lowering is: - * func = f1(arg)(f2(func)) - * - * This function takes a method ref, wraps it in the decorator calls and returns the resulting expression. - * In the example case this is: - * f1(arg)(f2(func)) - */ - def wrapMethodRefWithDecorators( - methodRefNode: nodes.NewNode, - decoratorList: Iterable[ast.iexpr] - ): nodes.NewNode = - decoratorList.foldRight(methodRefNode)( - (decorator: ast.iexpr, wrappedMethodRef: nodes.NewNode) => - val (decoratorNode, decoratorName) = convert(decorator) match - case decoratorNode: NewIdentifier => decoratorNode -> decoratorNode.name - case decoratorNode => - decoratorNode -> "" // other decorators are dynamic so we leave this blank - createCall( - decoratorNode, - decoratorName, - lineAndColOf(decorator), - wrappedMethodRef :: Nil, - Nil - ) - ) - - def convert(functionDef: ast.AsyncFunctionDef): NewNode = - val methodIdentifierNode = - createIdentifierNode(functionDef.name, Store, lineAndColOf(functionDef)) - val (methodNode, methodRefNode) = createMethodAndMethodRef( - functionDef.name, - Some(functionDef.name), - createParameterProcessingFunction( - functionDef.args, - isStaticMethod(functionDef.decorator_list) - ), - () => functionDef.body.map(convert), - functionDef.returns, - isAsync = true, - lineAndColOf(functionDef) + val fileNode = nodeBuilder.fileNode(relFileName) + val namespaceBlockNode = + nodeBuilder.namespaceBlockNode( + Constants.GLOBAL_NAMESPACE, + relFileName + ":" + Constants.GLOBAL_NAMESPACE, + relFileName ) - functionDefToMethod.put(functionDef, methodNode) + edgeBuilder.astEdge(namespaceBlockNode, fileNode, 1) + contextStack.setFileNamespaceBlock(namespaceBlockNode) - val wrappedMethodRefNode = - wrapMethodRefWithDecorators(methodRefNode, functionDef.decorator_list) + val methodFullName = calculateFullNameFromContext("") - createAssignment(methodIdentifierNode, wrappedMethodRefNode, lineAndColOf(functionDef)) - end convert - - private def isStaticMethod(decoratorList: Iterable[ast.iexpr]): Boolean = - decoratorList.exists { - case name: ast.Name if name.id == "staticmethod" => true - case _ => false - } - - private def isClassMethod(decoratorList: Iterable[ast.iexpr]): Boolean = - decoratorList.exists { - case name: ast.Name if name.id == "classmethod" => true - case _ => false - } - - private def createParameterProcessingFunction( - parameters: ast.Arguments, - isStatic: Boolean - ): () => MethodParameters = - val startIndex = - if contextStack.isClassContext && !isStatic then - 0 - else - 1 - - () => new MethodParameters(startIndex, convert(parameters, startIndex)) - - // TODO handle returns - private def createMethodAndMethodRef( - methodName: String, - scopeName: Option[String], - parameterProvider: () => MethodParameters, - bodyProvider: () => Iterable[nodes.NewNode], - returns: Option[ast.iexpr], - isAsync: Boolean, - lineAndColumn: LineAndColumn - ): (nodes.NewMethod, nodes.NewMethodRef) = - val methodFullName = calculateFullNameFromContext(methodName) - - val methodRefNode = - nodeBuilder.methodRefNode("def " + methodName + "(...)", methodFullName, lineAndColumn) - - val methodNode = - createMethod( - methodName, - methodFullName, - scopeName, - parameterProvider, - bodyProvider, - returns, - isAsync = true, - Some(methodRefNode), - returnTypeHint = None, - lineAndColumn - ) - - (methodNode, methodRefNode) - end createMethodAndMethodRef - - // It is important that the nodes returned by all provider function are created - // during the function invocation and not in advance. Because only - // than the context information is correct. - private def createMethod( - name: String, - fullName: String, - scopeName: Option[String], - parameterProvider: () => MethodParameters, - bodyProvider: () => Iterable[nodes.NewNode], - returns: Option[ast.iexpr], - isAsync: Boolean, - methodRefNode: Option[nodes.NewMethodRef], - returnTypeHint: Option[String], - lineAndColumn: LineAndColumn - ): nodes.NewMethod = - val methodNode = nodeBuilder.methodNode(name, fullName, relFileName, lineAndColumn) - edgeBuilder.astEdge(methodNode, contextStack.astParent, contextStack.order.getAndInc) - - val blockNode = nodeBuilder.blockNode("", lineAndColumn) - edgeBuilder.astEdge(blockNode, methodNode, 1) - - contextStack.pushMethod(scopeName, methodNode, blockNode, methodRefNode) - - val virtualModifierNode = nodeBuilder.modifierNode(ModifierTypes.VIRTUAL) - edgeBuilder.astEdge(virtualModifierNode, methodNode, 0) - - val methodParameter = parameterProvider() - val parameterOrder = new AutoIncIndex(methodParameter.posStartIndex) - - methodParameter.positionalParams.foreach { parameterNode => - contextStack.addParameter(parameterNode) - edgeBuilder.astEdge(parameterNode, methodNode, parameterOrder.getAndInc) - } - - val methodReturnNode = - nodeBuilder.methodReturnNode( - nodeBuilder.extractTypesFromHint(returns), - returnTypeHint, - lineAndColumn - ) - edgeBuilder.astEdge(methodReturnNode, methodNode, 2) - - val bodyOrder = new AutoIncIndex(1) - bodyProvider().foreach { bodyStmt => - edgeBuilder.astEdge(bodyStmt, blockNode, bodyOrder.getAndInc) - } - - // For every method we create a corresponding TYPE and TYPE_DECL and - // a binding for the method into TYPE_DECL. - val typeNode = nodeBuilder.typeNode(name, fullName) - val typeDeclNode = - nodeBuilder.typeDeclNode(name, fullName, relFileName, Seq(Constants.ANY), lineAndColumn) - - // For every method that is a module, the local variables can be imported by other modules. This behaviour is - // much like fields so they are to be linked as fields to this method type - if name == "" then contextStack.createMemberLinks(typeDeclNode, edgeBuilder.astEdge) - - contextStack.pop() - edgeBuilder.astEdge(typeDeclNode, contextStack.astParent, contextStack.order.getAndInc) - createBinding(methodNode, typeDeclNode) - - methodNode - end createMethod - - // For a classDef we do: - // 1. Create a metaType, metaTypeDecl and metaTypeRef. - // 2. Create a function containing the code of the classDef body. - // 3. Create a block which contains a call to the body function - // and an assignment of the metaTypeRef to an identifier with the class name. - // 4. Create type and typeDecl for the instance class. - // 5. Create and link members in metaTypeDecl and instanceTypeDecl - def convert(classDef: ast.ClassDef): NewNode = - // Create type for the meta class object - val metaTypeDeclName = classDef.name + metaClassSuffix - val metaTypeDeclFullName = calculateFullNameFromContext(metaTypeDeclName) - - val metaTypeNode = nodeBuilder.typeNode(metaTypeDeclName, metaTypeDeclFullName) - val metaTypeDeclNode = - nodeBuilder.typeDeclNode( - metaTypeDeclName, - metaTypeDeclFullName, - relFileName, - Seq(Constants.ANY), - lineAndColOf(classDef) - ) - edgeBuilder.astEdge(metaTypeDeclNode, contextStack.astParent, contextStack.order.getAndInc) - - // Create type for class instances - val instanceTypeDeclName = classDef.name - val instanceTypeDeclFullName = calculateFullNameFromContext(instanceTypeDeclName) - - // TODO for now we just take the code of the base expression and pretend they are full names, converting special - // nodes as we go. - def handleInheritance(fs: List[ast.iexpr]): List[String] = fs match - case (x: ast.Call) :: xs => - val node = convert(x) - val parent = contextStack.astParent - val tmpVar = createIdentifierNode(getUnusedName(), Store, lineAndColOf(x)) - val assignment = createAssignment(tmpVar, node, lineAndColOf(x)) - diffGraph.addEdge(parent, assignment, EdgeTypes.AST) - tmpVar.name +: handleInheritance(xs) - case x :: xs => - nodeToCode.getCode(x) +: handleInheritance(xs) - case Nil => Nil - - val inheritsFrom = handleInheritance(classDef.bases.toList) - - val instanceType = nodeBuilder.typeNode(instanceTypeDeclName, instanceTypeDeclFullName) - val instanceTypeDecl = - nodeBuilder.typeDeclNode( - instanceTypeDeclName, - instanceTypeDeclFullName, - relFileName, - inheritsFrom, - lineAndColOf(classDef) - ) - edgeBuilder.astEdge(instanceTypeDecl, contextStack.astParent, contextStack.order.getAndInc) - - // Create function which contains the code defining the class - contextStack.pushClass(Some(classDef.name), instanceTypeDecl) - val classBodyFunctionName = "" - val (_, methodRefNode) = createMethodAndMethodRef( - classBodyFunctionName, - scopeName = None, - parameterProvider = () => MethodParameters.empty(), - bodyProvider = () => classDef.body.map(convert), - None, - isAsync = false, - lineAndColOf(classDef) - ) - - contextStack.pop() - - contextStack.pushClass(Some(classDef.name), metaTypeDeclNode) - - // Create meta class call handling method and bind it to meta class type. - val functions = classDef.body.collect { case func: ast.FunctionDef => func } - - // __init__ method has to be in functions because "async def __init__" is invalid. - val initFunctionOption = functions.find(_.name == "__init__") - - val initParameters = initFunctionOption.map(_.args).getOrElse { - // Create arguments of a default __init__ function. - ast.Arguments( - posonlyargs = mutable.Seq.empty[ast.Arg], - args = mutable.Seq(ast.Arg("self", None, None, classDef.attributeProvider)), - vararg = None, - kwonlyargs = mutable.Seq.empty[ast.Arg], - kw_defaults = mutable.Seq.empty[Option[ast.iexpr]], - kw_arg = None, - defaults = mutable.Seq.empty[ast.iexpr] - ) - } - - val metaClassCallHandlerMethod = - createMetaClassCallHandlerMethod( - initParameters, - metaTypeDeclName, - metaTypeDeclFullName, - instanceTypeDeclFullName - ) - - createBinding(metaClassCallHandlerMethod, metaTypeDeclNode) - - // Create fake __new__ regardless whether there is an actual implementation in the code. - // We do this to model the __init__ call in a visible way for the data flow tracker. - // This is done because very often the __init__ call is hidden in a super().__new__ call - // and we cant yet handle super(). - val fakeNewMethod = createFakeNewMethod(initParameters) - - val fakeNewMember = nodeBuilder.memberNode("", fakeNewMethod.fullName) - edgeBuilder.astEdge(fakeNewMember, metaTypeDeclNode, contextStack.order.getAndInc) - - // Create binding into class instance type for each method. - // Also create bindings into meta class type to enable calls like "MyClass.func(obj, p1)". - // For non static methods we create an adapter method which basically only shifts the parameters - // one to the left and makes sure that the meta class object is not passed to func as instance - // parameter. - classDef.body.foreach { - case func: ast.FunctionDef => - createMemberBindingsAndAdapter( - func, - func.name, - func.args, - func.decorator_list, - instanceTypeDecl, - metaTypeDeclNode - ) - case func: ast.AsyncFunctionDef => - createMemberBindingsAndAdapter( - func, - func.name, - func.args, - func.decorator_list, - instanceTypeDecl, - metaTypeDeclNode - ) - case _ => - // All other body statements are currently ignored. - } - - contextStack.pop() - - // Create call to function and assignment of the meta class object to a identifier named - // like the class. - val callToClassBodyFunction = - createCall(methodRefNode, "", lineAndColOf(classDef), Nil, Nil) - val metaTypeRefNode = - createTypeRef(metaTypeDeclName, metaTypeDeclFullName, lineAndColOf(classDef)) - val classIdentifierAssignNode = - createAssignmentToIdentifier(classDef.name, metaTypeRefNode, lineAndColOf(classDef)) - - val classBlock = createBlock( - callToClassBodyFunction :: classIdentifierAssignNode :: Nil, - lineAndColOf(classDef) - ) - - classBlock - end convert - - private def createMemberBindingsAndAdapter( - function: ast.istmt, - functionName: String, - functionArgs: ast.Arguments, - functionDecoratorList: Iterable[ast.iexpr], - instanceTypeDecl: nodes.NewNode, - metaTypeDecl: nodes.NewNode - ): Unit = - val memberForInstance = - nodeBuilder.memberNode( - functionName, - functionDefToMethod.apply(function).fullName, - lineAndColOf(function) - ) - edgeBuilder.astEdge(memberForInstance, instanceTypeDecl, contextStack.order.getAndInc) - - val methodForMetaClass = - if isStaticMethod(functionDecoratorList) || isClassMethod(functionDecoratorList) then - functionDefToMethod.apply(function) - else - createMetaClassAdapterMethod( - functionName, - functionDefToMethod.apply(function).fullName, - functionArgs, - lineAndColOf(function) - ) - - val memberForMeta = nodeBuilder.memberNode( - functionName, - methodForMetaClass.fullName, - lineAndColOf(function) - ) - edgeBuilder.astEdge(memberForMeta, metaTypeDecl, contextStack.order.getAndInc) - end createMemberBindingsAndAdapter - - /** Creates an adapter method which adapts the meta class version of a method to the instance - * class version. Consider class: class MyClass(): def func(self, p1): pass - * - * The syntax to call func via the meta class is: MyClass.func(someInstance, p1), whereas the - * call via the instance itself is: someInstance.func(p1). To adapt between those two we - * generate: def func(cls, self, p1): return STATIC_CALL(MyClass.func(self, - * p1)) - * @return - */ - // TODO handle kwArg - private def createMetaClassAdapterMethod( - adaptedMethodName: String, - adaptedMethodFullName: String, - parameters: ast.Arguments, - lineAndColumn: LineAndColumn - ): nodes.NewMethod = - val adapterMethodName = adaptedMethodName + "" - val adapterMethodFullName = calculateFullNameFromContext(adapterMethodName) + val firstLineAndCol = module.stmts.headOption.map(lineAndColOf) + val lastLineAndCol = module.stmts.lastOption.map(lineAndColOf) + val line = firstLineAndCol.map(_.line).getOrElse(1) + val column = firstLineAndCol.map(_.column).getOrElse(1) + val endLine = lastLineAndCol.map(_.endLine).getOrElse(1) + val endColumn = lastLineAndCol.map(_.endColumn).getOrElse(1) + val moduleMethodNode = createMethod( - adapterMethodName, - adapterMethodFullName, - Some(adaptedMethodName), - parameterProvider = () => - MethodParameters( - 0, - nodeBuilder.methodParameterNode( - "cls", - isVariadic = false, - lineAndColumn, - Option(0) - ) :: Nil ++ - convert(parameters, 1) - ), + "", + methodFullName, + Some(""), + parameterProvider = () => MethodParameters.empty(), bodyProvider = () => - val (arguments, keywordArguments) = createArguments(parameters, lineAndColumn) - val staticCall = - createStaticCall( - adaptedMethodName, - adaptedMethodFullName, - lineAndColumn, - arguments, - keywordArguments - ) - val returnNode = createReturn(Some(staticCall), None, lineAndColumn) - returnNode :: Nil - , + createBuiltinIdentifiers(memOpCalculator.names) ++ module.stmts.map(convert), returns = None, isAsync = false, methodRefNode = None, returnTypeHint = None, - lineAndColumn + LineAndColumn(line, column, endLine, endColumn) ) - end createMetaClassAdapterMethod - - def createArguments( - arguments: ast.Arguments, - lineAndColumn: LineAndColumn - ): (Iterable[nodes.NewNode], Iterable[(String, nodes.NewNode)]) = - val convertedArgs = mutable.ArrayBuffer.empty[nodes.NewNode] - val convertedKeywordArgs = mutable.ArrayBuffer.empty[(String, nodes.NewNode)] - arguments.posonlyargs.foreach { arg => - convertedArgs.append(createIdentifierNode(arg.arg, Load, lineAndColumn)) - } - arguments.args.foreach { arg => - convertedArgs.append(createIdentifierNode(arg.arg, Load, lineAndColumn)) - } - arguments.vararg.foreach { arg => - convertedArgs.append( - createStarredUnpackOperatorCall( - createIdentifierNode(arg.arg, Load, lineAndColumn), - lineAndColumn - ) - ) - } - arguments.kwonlyargs.foreach { arg => - convertedKeywordArgs.append(( - arg.arg, - createIdentifierNode(arg.arg, Load, lineAndColumn) - )) - } - - (convertedArgs, convertedKeywordArgs) - end createArguments - - /** This function strips the first positional parameter from initParameters, if present. - * @return - * Parameters without first positional parameter and adjusted line and column number - * information. - */ - private def stripFirstPositionalParameter(initParameters: ast.Arguments) - : (ast.Arguments, LineAndColumn) = - if initParameters.posonlyargs.nonEmpty then - ( - initParameters.copy(posonlyargs = initParameters.posonlyargs.tail), - lineAndColOf(initParameters.posonlyargs.head) - ) - else if initParameters.args.nonEmpty then - ( - initParameters.copy(args = initParameters.args.tail), - lineAndColOf(initParameters.args.head) - ) - else if initParameters.vararg.nonEmpty then - (initParameters, lineAndColOf(initParameters.vararg.get)) + createIdentifierLinks() + + moduleMethodNode + end convert + + // Create assignments of type references to all builtin identifiers if the identifier appears + // at least once in the module. We filter in order to not generate the complete list of + // assignment in each module. + // That logic still generates assignments in cases where we would not need to, but figuring + // that out would mean we need to wait until all identifiers are linked which is than too + // late to create new identifiers and still use the same link mechanism. We would need + // to rearrange quite some code to accomplish that. So we leave that as an optional TODO. + // Note that namesUsedInModule is only calculated from ast.Name nodes! So e.g. new names + // artificially created during lowering are not in that collection which is fine for now. + private def createBuiltinIdentifiers(namesUsedInModule: collection.Set[String]) + : Iterable[nodes.NewNode] = + val result = mutable.ArrayBuffer.empty[nodes.NewNode] + val lineAndColumn = LineAndColumn(1, 1, 1, 1) + + val builtinFunctions = mutable.ArrayBuffer.empty[String] + val builtinClasses = mutable.ArrayBuffer.empty[String] + + if version == PythonV3 || version == PythonV2AndV3 then + builtinFunctions.appendAll(PythonAstVisitor.builtinFunctionsV3) + builtinClasses.appendAll(PythonAstVisitor.builtinClassesV3) + if version == PythonV2 || version == PythonV2AndV3 then + builtinFunctions.appendAll(PythonAstVisitor.builtinFunctionsV2) + builtinClasses.appendAll(PythonAstVisitor.builtinClassesV2) + + builtinFunctions.distinct.foreach { builtinObjectName => + if namesUsedInModule.contains(builtinObjectName) then + val assignmentNode = createAssignment( + createIdentifierNode(builtinObjectName, Store, lineAndColumn), + nodeBuilder + .typeRefNode( + "__builtins__." + builtinObjectName, + builtinPrefix + builtinObjectName, + lineAndColumn + ), + lineAndColumn + ) + + result.append(assignmentNode) + } + + builtinClasses.distinct.foreach { builtinObjectName => + if namesUsedInModule.contains(builtinObjectName) then + val assignmentNode = createAssignment( + createIdentifierNode(builtinObjectName, Store, lineAndColumn), + nodeBuilder.typeRefNode( + "__builtins__." + builtinObjectName, + builtinPrefix + builtinObjectName + metaClassSuffix, + lineAndColumn + ), + lineAndColumn + ) + + result.append(assignmentNode) + } + + result + end createBuiltinIdentifiers + + private def unhandled(node: ast.iast & ast.iattributes): NewNode = + val unhandledAsUnknown = true + if unhandledAsUnknown then + nodeBuilder.unknownNode(node.toString, node.getClass.getName, lineAndColOf(node)) + else + throw new NotImplementedError() + + def convert(stmt: ast.istmt): NewNode = + stmt match + case node: ast.FunctionDef => convert(node) + case node: ast.AsyncFunctionDef => convert(node) + case node: ast.ClassDef => convert(node) + case node: ast.Return => convert(node) + case node: ast.Delete => convert(node) + case node: ast.Assign => convert(node) + case node: ast.AnnAssign => convert(node) + case node: ast.AugAssign => convert(node) + case node: ast.For => convert(node) + case node: ast.AsyncFor => convert(node) + case node: ast.While => convert(node) + case node: ast.If => convert(node) + case node: ast.With => convert(node) + case node: ast.AsyncWith => convert(node) + case node: ast.Match => convert(node) + case node: ast.Raise => convert(node) + case node: ast.Try => convert(node) + case node: ast.Assert => convert(node) + case node: ast.Import => convert(node) + case node: ast.ImportFrom => convert(node) + case node: ast.Global => convert(node) + case node: ast.Nonlocal => convert(node) + case node: ast.Expr => convert(node) + case node: ast.Pass => convert(node) + case node: ast.Break => convert(node) + case node: ast.Continue => convert(node) + case node: ast.RaiseP2 => unhandled(node) + case node: ast.ErrorStatement => convert(node) + + def convert(functionDef: ast.FunctionDef): NewNode = + val methodIdentifierNode = + createIdentifierNode(functionDef.name, Store, lineAndColOf(functionDef)) + val (methodNode, methodRefNode) = createMethodAndMethodRef( + functionDef.name, + Some(functionDef.name), + createParameterProcessingFunction( + functionDef.args, + isStaticMethod(functionDef.decorator_list) + ), + () => functionDef.body.map(convert), + functionDef.returns, + isAsync = false, + lineAndColOf(functionDef) + ) + functionDefToMethod.put(functionDef, methodNode) + + val wrappedMethodRefNode = + wrapMethodRefWithDecorators(methodRefNode, functionDef.decorator_list) + + createAssignment(methodIdentifierNode, wrappedMethodRefNode, lineAndColOf(functionDef)) + end convert + + /* + * For a decorated function like: + * @f1(arg) + * @f2 + * def func(): pass + * + * The lowering is: + * func = f1(arg)(f2(func)) + * + * This function takes a method ref, wraps it in the decorator calls and returns the resulting expression. + * In the example case this is: + * f1(arg)(f2(func)) + */ + def wrapMethodRefWithDecorators( + methodRefNode: nodes.NewNode, + decoratorList: Iterable[ast.iexpr] + ): nodes.NewNode = + decoratorList.foldRight(methodRefNode)( + (decorator: ast.iexpr, wrappedMethodRef: nodes.NewNode) => + val (decoratorNode, decoratorName) = convert(decorator) match + case decoratorNode: NewIdentifier => decoratorNode -> decoratorNode.name + case decoratorNode => + decoratorNode -> "" // other decorators are dynamic so we leave this blank + createCall( + decoratorNode, + decoratorName, + lineAndColOf(decorator), + wrappedMethodRef :: Nil, + Nil + ) + ) + + def convert(functionDef: ast.AsyncFunctionDef): NewNode = + val methodIdentifierNode = + createIdentifierNode(functionDef.name, Store, lineAndColOf(functionDef)) + val (methodNode, methodRefNode) = createMethodAndMethodRef( + functionDef.name, + Some(functionDef.name), + createParameterProcessingFunction( + functionDef.args, + isStaticMethod(functionDef.decorator_list) + ), + () => functionDef.body.map(convert), + functionDef.returns, + isAsync = true, + lineAndColOf(functionDef) + ) + functionDefToMethod.put(functionDef, methodNode) + + val wrappedMethodRefNode = + wrapMethodRefWithDecorators(methodRefNode, functionDef.decorator_list) + + createAssignment(methodIdentifierNode, wrappedMethodRefNode, lineAndColOf(functionDef)) + end convert + + private def isStaticMethod(decoratorList: Iterable[ast.iexpr]): Boolean = + decoratorList.exists { + case name: ast.Name if name.id == "staticmethod" => true + case _ => false + } + + private def isClassMethod(decoratorList: Iterable[ast.iexpr]): Boolean = + decoratorList.exists { + case name: ast.Name if name.id == "classmethod" => true + case _ => false + } + + private def createParameterProcessingFunction( + parameters: ast.Arguments, + isStatic: Boolean + ): () => MethodParameters = + val startIndex = + if contextStack.isClassContext && !isStatic then + 0 else - (initParameters, lineAndColOf(initParameters.kw_arg.get)) - - /** Creates the method which handles a call to the meta class object. This process is also known - * as creating a new instance object, e.g. obj = MyClass(p1). The purpose of the generated - * function is to adapt between the special cased instance creation call and a normal call to - * __new__ (for now ). The adaption is required to in order to provide TYPE_REF(meta - * class) as instance argument to __new__/. So the looks like: - * def (p1): return DYNAMIC_CALL(receiver=TYPE_REF(meta class)., - * instance \= TYPE_REF(meta class), p1) - */ - // TODO handle kwArg - private def createMetaClassCallHandlerMethod( - initParameters: ast.Arguments, - metaTypeDeclName: String, - metaTypeDeclFullName: String, - instanceTypeDeclFullName: String - ): nodes.NewMethod = - val methodName = "" - val methodFullName = calculateFullNameFromContext(methodName) - - // We need to drop the "self" parameter either from the position only or normal parameters - // because "self" is not passed through but rather created in __new__. - val (parametersWithoutSelf, lineAndColumn) = stripFirstPositionalParameter(initParameters) - + 1 + + () => new MethodParameters(startIndex, convert(parameters, startIndex)) + + // TODO handle returns + private def createMethodAndMethodRef( + methodName: String, + scopeName: Option[String], + parameterProvider: () => MethodParameters, + bodyProvider: () => Iterable[nodes.NewNode], + returns: Option[ast.iexpr], + isAsync: Boolean, + lineAndColumn: LineAndColumn + ): (nodes.NewMethod, nodes.NewMethodRef) = + val methodFullName = calculateFullNameFromContext(methodName) + + val methodRefNode = + nodeBuilder.methodRefNode("def " + methodName + "(...)", methodFullName, lineAndColumn) + + val methodNode = createMethod( methodName, methodFullName, - Some(methodName), - parameterProvider = () => - MethodParameters(1, convert(parametersWithoutSelf, 1)), - bodyProvider = () => - val (arguments, keywordArguments) = - createArguments(parametersWithoutSelf, lineAndColumn) - - val fakeNewCall = createInstanceCall( - createFieldAccess( - createTypeRef(metaTypeDeclName, metaTypeDeclFullName, lineAndColumn), - "", - lineAndColumn - ), - createTypeRef(metaTypeDeclName, metaTypeDeclFullName, lineAndColumn), - "", - lineAndColumn, - arguments, - keywordArguments - ) - - val returnNode = createReturn(Some(fakeNewCall), None, lineAndColumn) - - returnNode :: Nil - , - returns = None, - isAsync = false, - methodRefNode = None, - returnTypeHint = Some(instanceTypeDeclFullName), + scopeName, + parameterProvider, + bodyProvider, + returns, + isAsync = true, + Some(methodRefNode), + returnTypeHint = None, lineAndColumn ) - end createMetaClassCallHandlerMethod - - /** Creates a method which mimics the behaviour of a default __new__ method (the one - * you would get if no implementation is present). The reason we use a fake version of the - * __new__ method it that we wont be able to correctly track through most custom __new__ - * implementations as they usually call "super.__init__()" and we cannot yet handle "super". - * The fake __new__ looks like: def (cls, p1): __newInstance = - * STATIC_CALL(.alloc) cls.__init__(__newIstance, p1) return __newInstance - */ - // TODO handle kwArg - private def createFakeNewMethod(initParameters: ast.Arguments): nodes.NewMethod = - val newMethodName = "" - val newMethodStubFullName = calculateFullNameFromContext(newMethodName) - - // We need to drop the "self" parameter either from the position only or normal parameters - // because "self" is not passed through but rather created in __new__. - val (parametersWithoutSelf, lineAndColumn) = stripFirstPositionalParameter(initParameters) - createMethod( - newMethodName, - newMethodStubFullName, - Some(newMethodName), - parameterProvider = () => - MethodParameters( - 0, - nodeBuilder.methodParameterNode( - "cls", - isVariadic = false, - lineAndColumn, - Some(0) - ) :: Nil ++ - convert(parametersWithoutSelf, 1) - ), - bodyProvider = () => - val allocatorCall = - createNAryOperatorCall( - () => (".alloc", ".alloc"), - Nil, - lineAndColumn - ) - val assignmentToNewInstance = - createAssignment( - createIdentifierNode("__newInstance", Store, lineAndColumn), - allocatorCall, - lineAndColumn - ) - - val (arguments, keywordArguments) = - createArguments(parametersWithoutSelf, lineAndColumn) - val argumentWithInstance = mutable.ArrayBuffer.empty[nodes.NewNode] - argumentWithInstance.append(createIdentifierNode( - "__newInstance", - Load, - lineAndColumn - )) - argumentWithInstance.appendAll(arguments) - - val initCall = createXDotYCall( - () => createIdentifierNode("cls", Load, lineAndColumn), - "__init__", - xMayHaveSideEffects = false, - lineAndColumn, - argumentWithInstance, - keywordArguments - ) - - val returnNode = - createReturn( - Some(createIdentifierNode("__newInstance", Load, lineAndColumn)), - None, - lineAndColumn - ) - - assignmentToNewInstance :: initCall :: returnNode :: Nil - , - returns = None, - isAsync = false, - methodRefNode = None, - returnTypeHint = None, + (methodNode, methodRefNode) + end createMethodAndMethodRef + + // It is important that the nodes returned by all provider function are created + // during the function invocation and not in advance. Because only + // than the context information is correct. + private def createMethod( + name: String, + fullName: String, + scopeName: Option[String], + parameterProvider: () => MethodParameters, + bodyProvider: () => Iterable[nodes.NewNode], + returns: Option[ast.iexpr], + isAsync: Boolean, + methodRefNode: Option[nodes.NewMethodRef], + returnTypeHint: Option[String], + lineAndColumn: LineAndColumn + ): nodes.NewMethod = + val methodNode = nodeBuilder.methodNode(name, fullName, relFileName, lineAndColumn) + edgeBuilder.astEdge(methodNode, contextStack.astParent, contextStack.order.getAndInc) + + val blockNode = nodeBuilder.blockNode("", lineAndColumn) + edgeBuilder.astEdge(blockNode, methodNode, 1) + + contextStack.pushMethod(scopeName, methodNode, blockNode, methodRefNode) + + val virtualModifierNode = nodeBuilder.modifierNode(ModifierTypes.VIRTUAL) + edgeBuilder.astEdge(virtualModifierNode, methodNode, 0) + + val methodParameter = parameterProvider() + val parameterOrder = new AutoIncIndex(methodParameter.posStartIndex) + + methodParameter.positionalParams.foreach { parameterNode => + contextStack.addParameter(parameterNode) + edgeBuilder.astEdge(parameterNode, methodNode, parameterOrder.getAndInc) + } + + val methodReturnNode = + nodeBuilder.methodReturnNode( + nodeBuilder.extractTypesFromHint(returns), + returnTypeHint, lineAndColumn ) - end createFakeNewMethod - - def convert(ret: ast.Return): NewNode = - createReturn(ret.value.map(convert), Some(nodeToCode.getCode(ret)), lineAndColOf(ret)) - - def convert(delete: ast.Delete): NewNode = - val deleteArgs = delete.targets.map(convert) - - val code = "del " + deleteArgs.map(codeOf).mkString(", ") - val callNode = nodeBuilder.callNode( - code, - ".delete", - DispatchTypes.STATIC_DISPATCH, - lineAndColOf(delete) + edgeBuilder.astEdge(methodReturnNode, methodNode, 2) + + val bodyOrder = new AutoIncIndex(1) + bodyProvider().foreach { bodyStmt => + edgeBuilder.astEdge(bodyStmt, blockNode, bodyOrder.getAndInc) + } + + // For every method we create a corresponding TYPE and TYPE_DECL and + // a binding for the method into TYPE_DECL. + val typeNode = nodeBuilder.typeNode(name, fullName) + val typeDeclNode = + nodeBuilder.typeDeclNode(name, fullName, relFileName, Seq(Constants.ANY), lineAndColumn) + + // For every method that is a module, the local variables can be imported by other modules. This behaviour is + // much like fields so they are to be linked as fields to this method type + if name == "" then contextStack.createMemberLinks(typeDeclNode, edgeBuilder.astEdge) + + contextStack.pop() + edgeBuilder.astEdge(typeDeclNode, contextStack.astParent, contextStack.order.getAndInc) + createBinding(methodNode, typeDeclNode) + + methodNode + end createMethod + + // For a classDef we do: + // 1. Create a metaType, metaTypeDecl and metaTypeRef. + // 2. Create a function containing the code of the classDef body. + // 3. Create a block which contains a call to the body function + // and an assignment of the metaTypeRef to an identifier with the class name. + // 4. Create type and typeDecl for the instance class. + // 5. Create and link members in metaTypeDecl and instanceTypeDecl + def convert(classDef: ast.ClassDef): NewNode = + // Create type for the meta class object + val metaTypeDeclName = classDef.name + metaClassSuffix + val metaTypeDeclFullName = calculateFullNameFromContext(metaTypeDeclName) + + val metaTypeNode = nodeBuilder.typeNode(metaTypeDeclName, metaTypeDeclFullName) + val metaTypeDeclNode = + nodeBuilder.typeDeclNode( + metaTypeDeclName, + metaTypeDeclFullName, + relFileName, + Seq(Constants.ANY), + lineAndColOf(classDef) ) + edgeBuilder.astEdge(metaTypeDeclNode, contextStack.astParent, contextStack.order.getAndInc) + + // Create type for class instances + val instanceTypeDeclName = classDef.name + val instanceTypeDeclFullName = calculateFullNameFromContext(instanceTypeDeclName) + + // TODO for now we just take the code of the base expression and pretend they are full names, converting special + // nodes as we go. + def handleInheritance(fs: List[ast.iexpr]): List[String] = fs match + case (x: ast.Call) :: xs => + val node = convert(x) + val parent = contextStack.astParent + val tmpVar = createIdentifierNode(getUnusedName(), Store, lineAndColOf(x)) + val assignment = createAssignment(tmpVar, node, lineAndColOf(x)) + diffGraph.addEdge(parent, assignment, EdgeTypes.AST) + tmpVar.name +: handleInheritance(xs) + case x :: xs => + nodeToCode.getCode(x) +: handleInheritance(xs) + case Nil => Nil + + val inheritsFrom = handleInheritance(classDef.bases.toList) + + val instanceType = nodeBuilder.typeNode(instanceTypeDeclName, instanceTypeDeclFullName) + val instanceTypeDecl = + nodeBuilder.typeDeclNode( + instanceTypeDeclName, + instanceTypeDeclFullName, + relFileName, + inheritsFrom, + lineAndColOf(classDef) + ) + edgeBuilder.astEdge(instanceTypeDecl, contextStack.astParent, contextStack.order.getAndInc) + + // Create function which contains the code defining the class + contextStack.pushClass(Some(classDef.name), instanceTypeDecl) + val classBodyFunctionName = "" + val (_, methodRefNode) = createMethodAndMethodRef( + classBodyFunctionName, + scopeName = None, + parameterProvider = () => MethodParameters.empty(), + bodyProvider = () => classDef.body.map(convert), + None, + isAsync = false, + lineAndColOf(classDef) + ) - addAstChildrenAsArguments(callNode, 1, deleteArgs) - callNode + contextStack.pop() - def convert(assign: ast.Assign): nodes.NewNode = - val loweredNodes = - createValueToTargetsDecomposition( - assign.targets, - convert(assign.value), - lineAndColOf(assign) - ) + contextStack.pushClass(Some(classDef.name), metaTypeDeclNode) - if loweredNodes.size == 1 then - // Simple assignment can be returned directly. - loweredNodes.head - else - createBlock(loweredNodes, lineAndColOf(assign)) + // Create meta class call handling method and bind it to meta class type. + val functions = classDef.body.collect { case func: ast.FunctionDef => func } - // TODO for now we ignore the annotation part and just emit the pure - // assignment. - def convert(annotatedAssign: ast.AnnAssign): NewNode = - val targetNode = convert(annotatedAssign.target) + // __init__ method has to be in functions because "async def __init__" is invalid. + val initFunctionOption = functions.find(_.name == "__init__") - annotatedAssign.value match - case Some(value) => - val valueNode = convert(value) - createAssignment(targetNode, valueNode, lineAndColOf(annotatedAssign)) - case None => - // If there is no value, this is just an expr: annotation and since - // we for now ignore the annotation we emit just the expr because - // it may have side effects. - targetNode - - def convert(augAssign: ast.AugAssign): NewNode = - val targetNode = convert(augAssign.target) - val valueNode = convert(augAssign.value) - - val (operatorCode, operatorFullName) = - augAssign.op match - case ast.Add => ("+=", Operators.assignmentPlus) - case ast.Sub => ("-=", Operators.assignmentMinus) - case ast.Mult => ("*=", Operators.assignmentMultiplication) - case ast.MatMult => - ( - "@=", - ".assignmentMatMult" - ) // TODO make this a define and add policy for this - case ast.Div => ("/=", Operators.assignmentDivision) - case ast.Mod => ("%=", Operators.assignmentModulo) - case ast.Pow => ("**=", Operators.assignmentExponentiation) - case ast.LShift => ("<<=", Operators.assignmentShiftLeft) - case ast.RShift => ("<<=", Operators.assignmentArithmeticShiftRight) - case ast.BitOr => ("|=", Operators.assignmentOr) - case ast.BitXor => ("^=", Operators.assignmentXor) - case ast.BitAnd => ("&=", Operators.assignmentAnd) - case ast.FloorDiv => - ( - "//=", - ".assignmentFloorDiv" - ) // TODO make this a define and add policy for this - - createAugAssignment( - targetNode, - operatorCode, - valueNode, - operatorFullName, - lineAndColOf(augAssign) + val initParameters = initFunctionOption.map(_.args).getOrElse { + // Create arguments of a default __init__ function. + ast.Arguments( + posonlyargs = mutable.Seq.empty[ast.Arg], + args = mutable.Seq(ast.Arg("self", None, None, classDef.attributeProvider)), + vararg = None, + kwonlyargs = mutable.Seq.empty[ast.Arg], + kw_defaults = mutable.Seq.empty[Option[ast.iexpr]], + kw_arg = None, + defaults = mutable.Seq.empty[ast.iexpr] ) - end convert - - // TODO write test - def convert(forStmt: ast.For): NewNode = - createForLowering( - forStmt.target, - forStmt.iter, - Iterable.empty, - forStmt.body.map(convert), - forStmt.orelse.map(convert), - isAsync = false, - lineAndColOf(forStmt) + } + + val metaClassCallHandlerMethod = + createMetaClassCallHandlerMethod( + initParameters, + metaTypeDeclName, + metaTypeDeclFullName, + instanceTypeDeclFullName ) - def convert(forStmt: ast.AsyncFor): NewNode = - createForLowering( - forStmt.target, - forStmt.iter, - Iterable.empty, - forStmt.body.map(convert), - forStmt.orelse.map(convert), - isAsync = true, - lineAndColOf(forStmt) + createBinding(metaClassCallHandlerMethod, metaTypeDeclNode) + + // Create fake __new__ regardless whether there is an actual implementation in the code. + // We do this to model the __init__ call in a visible way for the data flow tracker. + // This is done because very often the __init__ call is hidden in a super().__new__ call + // and we cant yet handle super(). + val fakeNewMethod = createFakeNewMethod(initParameters) + + val fakeNewMember = nodeBuilder.memberNode("", fakeNewMethod.fullName) + edgeBuilder.astEdge(fakeNewMember, metaTypeDeclNode, contextStack.order.getAndInc) + + // Create binding into class instance type for each method. + // Also create bindings into meta class type to enable calls like "MyClass.func(obj, p1)". + // For non static methods we create an adapter method which basically only shifts the parameters + // one to the left and makes sure that the meta class object is not passed to func as instance + // parameter. + classDef.body.foreach { + case func: ast.FunctionDef => + createMemberBindingsAndAdapter( + func, + func.name, + func.args, + func.decorator_list, + instanceTypeDecl, + metaTypeDeclNode + ) + case func: ast.AsyncFunctionDef => + createMemberBindingsAndAdapter( + func, + func.name, + func.args, + func.decorator_list, + instanceTypeDecl, + metaTypeDeclNode + ) + case _ => + // All other body statements are currently ignored. + } + + contextStack.pop() + + // Create call to function and assignment of the meta class object to a identifier named + // like the class. + val callToClassBodyFunction = + createCall(methodRefNode, "", lineAndColOf(classDef), Nil, Nil) + val metaTypeRefNode = + createTypeRef(metaTypeDeclName, metaTypeDeclFullName, lineAndColOf(classDef)) + val classIdentifierAssignNode = + createAssignmentToIdentifier(classDef.name, metaTypeRefNode, lineAndColOf(classDef)) + + val classBlock = createBlock( + callToClassBodyFunction :: classIdentifierAssignNode :: Nil, + lineAndColOf(classDef) + ) + + classBlock + end convert + + private def createMemberBindingsAndAdapter( + function: ast.istmt, + functionName: String, + functionArgs: ast.Arguments, + functionDecoratorList: Iterable[ast.iexpr], + instanceTypeDecl: nodes.NewNode, + metaTypeDecl: nodes.NewNode + ): Unit = + val memberForInstance = + nodeBuilder.memberNode( + functionName, + functionDefToMethod.apply(function).fullName, + lineAndColOf(function) ) + edgeBuilder.astEdge(memberForInstance, instanceTypeDecl, contextStack.order.getAndInc) - // Lowering of for x in y: : - // { - // iterator = y.__iter__() - // while (UNKNOWN condition): - // (x = iterator.__next__()) - // - // } - // If one "if" is present the lowering of for x in y if z: - // { - // iterator = y.__iter__() - // while (UNKNOWN condition): - // if (!z): continue - // (x = iterator.__next__()) - // - // } - // If multiple "ifs" are present the lowering of for x in y if z if a: ..,: - // { - // iterator = y.__iter__() - // while (UNKNOWN condition): - // if (!(z and a)): continue - // (x = iterator.__next__()) - // - // } - protected def createForLowering( - target: ast.iexpr, - iter: ast.iexpr, - ifs: Iterable[ast.iexpr], - bodyNodes: Iterable[nodes.NewNode], - orelseNodes: Iterable[nodes.NewNode], - isAsync: Boolean, - lineAndColumn: LineAndColumn - ): nodes.NewNode = - val iterVariableName = getUnusedName() - val iterExprIterCallNode = - createXDotYCall( - () => convert(iter), - "__iter__", - xMayHaveSideEffects = !iter.isInstanceOf[ast.Name], + val methodForMetaClass = + if isStaticMethod(functionDecoratorList) || isClassMethod(functionDecoratorList) then + functionDefToMethod.apply(function) + else + createMetaClassAdapterMethod( + functionName, + functionDefToMethod.apply(function).fullName, + functionArgs, + lineAndColOf(function) + ) + + val memberForMeta = nodeBuilder.memberNode( + functionName, + methodForMetaClass.fullName, + lineAndColOf(function) + ) + edgeBuilder.astEdge(memberForMeta, metaTypeDecl, contextStack.order.getAndInc) + end createMemberBindingsAndAdapter + + /** Creates an adapter method which adapts the meta class version of a method to the instance + * class version. Consider class: class MyClass(): def func(self, p1): pass + * + * The syntax to call func via the meta class is: MyClass.func(someInstance, p1), whereas the + * call via the instance itself is: someInstance.func(p1). To adapt between those two we + * generate: def func(cls, self, p1): return STATIC_CALL(MyClass.func(self, + * p1)) + * @return + */ + // TODO handle kwArg + private def createMetaClassAdapterMethod( + adaptedMethodName: String, + adaptedMethodFullName: String, + parameters: ast.Arguments, + lineAndColumn: LineAndColumn + ): nodes.NewMethod = + val adapterMethodName = adaptedMethodName + "" + val adapterMethodFullName = calculateFullNameFromContext(adapterMethodName) + + createMethod( + adapterMethodName, + adapterMethodFullName, + Some(adaptedMethodName), + parameterProvider = () => + MethodParameters( + 0, + nodeBuilder.methodParameterNode( + "cls", + isVariadic = false, lineAndColumn, - Nil, - Nil + Option(0) + ) :: Nil ++ + convert(parameters, 1) + ), + bodyProvider = () => + val (arguments, keywordArguments) = createArguments(parameters, lineAndColumn) + val staticCall = + createStaticCall( + adaptedMethodName, + adaptedMethodFullName, + lineAndColumn, + arguments, + keywordArguments ) - val iterAssignNode = - createAssignmentToIdentifier(iterVariableName, iterExprIterCallNode, lineAndColumn) - - val conditionNode = - nodeBuilder.unknownNode("iteratorNonEmptyOrException", "", lineAndColumn) + val returnNode = createReturn(Some(staticCall), None, lineAndColumn) + returnNode :: Nil + , + returns = None, + isAsync = false, + methodRefNode = None, + returnTypeHint = None, + lineAndColumn + ) + end createMetaClassAdapterMethod + + def createArguments( + arguments: ast.Arguments, + lineAndColumn: LineAndColumn + ): (Iterable[nodes.NewNode], Iterable[(String, nodes.NewNode)]) = + val convertedArgs = mutable.ArrayBuffer.empty[nodes.NewNode] + val convertedKeywordArgs = mutable.ArrayBuffer.empty[(String, nodes.NewNode)] + + arguments.posonlyargs.foreach { arg => + convertedArgs.append(createIdentifierNode(arg.arg, Load, lineAndColumn)) + } + arguments.args.foreach { arg => + convertedArgs.append(createIdentifierNode(arg.arg, Load, lineAndColumn)) + } + arguments.vararg.foreach { arg => + convertedArgs.append( + createStarredUnpackOperatorCall( + createIdentifierNode(arg.arg, Load, lineAndColumn), + lineAndColumn + ) + ) + } + arguments.kwonlyargs.foreach { arg => + convertedKeywordArgs.append(( + arg.arg, + createIdentifierNode(arg.arg, Load, lineAndColumn) + )) + } + + (convertedArgs, convertedKeywordArgs) + end createArguments + + /** This function strips the first positional parameter from initParameters, if present. + * @return + * Parameters without first positional parameter and adjusted line and column number + * information. + */ + private def stripFirstPositionalParameter(initParameters: ast.Arguments) + : (ast.Arguments, LineAndColumn) = + if initParameters.posonlyargs.nonEmpty then + ( + initParameters.copy(posonlyargs = initParameters.posonlyargs.tail), + lineAndColOf(initParameters.posonlyargs.head) + ) + else if initParameters.args.nonEmpty then + ( + initParameters.copy(args = initParameters.args.tail), + lineAndColOf(initParameters.args.head) + ) + else if initParameters.vararg.nonEmpty then + (initParameters, lineAndColOf(initParameters.vararg.get)) + else + (initParameters, lineAndColOf(initParameters.kw_arg.get)) + + /** Creates the method which handles a call to the meta class object. This process is also known + * as creating a new instance object, e.g. obj = MyClass(p1). The purpose of the generated + * function is to adapt between the special cased instance creation call and a normal call to + * __new__ (for now ). The adaption is required to in order to provide TYPE_REF(meta + * class) as instance argument to __new__/. So the looks like: + * def (p1): return DYNAMIC_CALL(receiver=TYPE_REF(meta class)., + * instance \= TYPE_REF(meta class), p1) + */ + // TODO handle kwArg + private def createMetaClassCallHandlerMethod( + initParameters: ast.Arguments, + metaTypeDeclName: String, + metaTypeDeclFullName: String, + instanceTypeDeclFullName: String + ): nodes.NewMethod = + val methodName = "" + val methodFullName = calculateFullNameFromContext(methodName) + + // We need to drop the "self" parameter either from the position only or normal parameters + // because "self" is not passed through but rather created in __new__. + val (parametersWithoutSelf, lineAndColumn) = stripFirstPositionalParameter(initParameters) + + createMethod( + methodName, + methodFullName, + Some(methodName), + parameterProvider = () => + MethodParameters(1, convert(parametersWithoutSelf, 1)), + bodyProvider = () => + val (arguments, keywordArguments) = + createArguments(parametersWithoutSelf, lineAndColumn) + + val fakeNewCall = createInstanceCall( + createFieldAccess( + createTypeRef(metaTypeDeclName, metaTypeDeclFullName, lineAndColumn), + "", + lineAndColumn + ), + createTypeRef(metaTypeDeclName, metaTypeDeclFullName, lineAndColumn), + "", + lineAndColumn, + arguments, + keywordArguments + ) - val controlStructureNode = - nodeBuilder.controlStructureNode( - "while ... : ...", - ControlStructureTypes.WHILE, - lineAndColumn - ) - edgeBuilder.conditionEdge(conditionNode, controlStructureNode) + val returnNode = createReturn(Some(fakeNewCall), None, lineAndColumn) - val iterNextCallNode = - createXDotYCall( - () => createIdentifierNode(iterVariableName, Load, lineAndColumn), - "__next__", - xMayHaveSideEffects = false, + returnNode :: Nil + , + returns = None, + isAsync = false, + methodRefNode = None, + returnTypeHint = Some(instanceTypeDeclFullName), + lineAndColumn + ) + end createMetaClassCallHandlerMethod + + /** Creates a method which mimics the behaviour of a default __new__ method (the one you + * would get if no implementation is present). The reason we use a fake version of the __new__ + * method it that we wont be able to correctly track through most custom __new__ implementations + * as they usually call "super.__init__()" and we cannot yet handle "super". The fake __new__ + * looks like: def (cls, p1): __newInstance = STATIC_CALL(.alloc) + * cls.__init__(__newIstance, p1) return __newInstance + */ + // TODO handle kwArg + private def createFakeNewMethod(initParameters: ast.Arguments): nodes.NewMethod = + val newMethodName = "" + val newMethodStubFullName = calculateFullNameFromContext(newMethodName) + + // We need to drop the "self" parameter either from the position only or normal parameters + // because "self" is not passed through but rather created in __new__. + val (parametersWithoutSelf, lineAndColumn) = stripFirstPositionalParameter(initParameters) + + createMethod( + newMethodName, + newMethodStubFullName, + Some(newMethodName), + parameterProvider = () => + MethodParameters( + 0, + nodeBuilder.methodParameterNode( + "cls", + isVariadic = false, lineAndColumn, + Some(0) + ) :: Nil ++ + convert(parametersWithoutSelf, 1) + ), + bodyProvider = () => + val allocatorCall = + createNAryOperatorCall( + () => (".alloc", ".alloc"), Nil, - Nil + lineAndColumn ) - - val loweredAssignNodes = - createValueToTargetsDecomposition( - Iterable.single(target), - iterNextCallNode, + val assignmentToNewInstance = + createAssignment( + createIdentifierNode("__newInstance", Store, lineAndColumn), + allocatorCall, lineAndColumn ) - val blockStmtNodes = mutable.ArrayBuffer.empty[nodes.NewNode] - blockStmtNodes.appendAll(loweredAssignNodes) - - if ifs.nonEmpty then - val conditionNode = - if ifs.size == 1 then - ifs.head - else - ast.BoolOp(ast.And, ifs.to(mutable.Seq), ifs.head.attributeProvider) - val ifNotContinueNode = convert( - ast.If( - ast.UnaryOp(ast.Not, conditionNode, ifs.head.attributeProvider), - mutable.ArrayBuffer.empty[ast.istmt].append( - ast.Continue(ifs.head.attributeProvider) - ), - mutable.Seq.empty[ast.istmt], - ifs.head.attributeProvider - ) - ) + val (arguments, keywordArguments) = + createArguments(parametersWithoutSelf, lineAndColumn) + val argumentWithInstance = mutable.ArrayBuffer.empty[nodes.NewNode] + argumentWithInstance.append(createIdentifierNode( + "__newInstance", + Load, + lineAndColumn + )) + argumentWithInstance.appendAll(arguments) - blockStmtNodes.append(ifNotContinueNode) - bodyNodes.foreach(blockStmtNodes.append) + val initCall = createXDotYCall( + () => createIdentifierNode("cls", Load, lineAndColumn), + "__init__", + xMayHaveSideEffects = false, + lineAndColumn, + argumentWithInstance, + keywordArguments + ) - val bodyBlockNode = createBlock(blockStmtNodes, lineAndColumn) - addAstChildNodes(controlStructureNode, 1, conditionNode, bodyBlockNode) + val returnNode = + createReturn( + Some(createIdentifierNode("__newInstance", Load, lineAndColumn)), + None, + lineAndColumn + ) - if orelseNodes.nonEmpty then - val elseBlockNode = createBlock(orelseNodes, lineAndColumn) - addAstChildNodes(controlStructureNode, 3, elseBlockNode) + assignmentToNewInstance :: initCall :: returnNode :: Nil + , + returns = None, + isAsync = false, + methodRefNode = None, + returnTypeHint = None, + lineAndColumn + ) + end createFakeNewMethod - createBlock(iterAssignNode :: controlStructureNode :: Nil, lineAndColumn) - end createForLowering + def convert(ret: ast.Return): NewNode = + createReturn(ret.value.map(convert), Some(nodeToCode.getCode(ret)), lineAndColOf(ret)) - def convert(astWhile: ast.While): nodes.NewNode = - val conditionNode = convert(astWhile.test) - val bodyStmtNodes = astWhile.body.map(convert) + def convert(delete: ast.Delete): NewNode = + val deleteArgs = delete.targets.map(convert) - val controlStructureNode = - nodeBuilder.controlStructureNode( - "while ... : ...", - ControlStructureTypes.WHILE, - lineAndColOf(astWhile) - ) - edgeBuilder.conditionEdge(conditionNode, controlStructureNode) + val code = "del " + deleteArgs.map(codeOf).mkString(", ") + val callNode = nodeBuilder.callNode( + code, + ".delete", + DispatchTypes.STATIC_DISPATCH, + lineAndColOf(delete) + ) - val bodyBlockNode = createBlock(bodyStmtNodes, lineAndColOf(astWhile)) - addAstChildNodes(controlStructureNode, 1, conditionNode, bodyBlockNode) + addAstChildrenAsArguments(callNode, 1, deleteArgs) + callNode - if astWhile.orelse.nonEmpty then - val elseStmtNodes = astWhile.orelse.map(convert) - val elseBlockNode = - createBlock(elseStmtNodes, lineAndColOf(astWhile.orelse.head)) - addAstChildNodes(controlStructureNode, 3, elseBlockNode) + def convert(assign: ast.Assign): nodes.NewNode = + val loweredNodes = + createValueToTargetsDecomposition( + assign.targets, + convert(assign.value), + lineAndColOf(assign) + ) - controlStructureNode - end convert + if loweredNodes.size == 1 then + // Simple assignment can be returned directly. + loweredNodes.head + else + createBlock(loweredNodes, lineAndColOf(assign)) + + // TODO for now we ignore the annotation part and just emit the pure + // assignment. + def convert(annotatedAssign: ast.AnnAssign): NewNode = + val targetNode = convert(annotatedAssign.target) + + annotatedAssign.value match + case Some(value) => + val valueNode = convert(value) + createAssignment(targetNode, valueNode, lineAndColOf(annotatedAssign)) + case None => + // If there is no value, this is just an expr: annotation and since + // we for now ignore the annotation we emit just the expr because + // it may have side effects. + targetNode + + def convert(augAssign: ast.AugAssign): NewNode = + val targetNode = convert(augAssign.target) + val valueNode = convert(augAssign.value) + + val (operatorCode, operatorFullName) = + augAssign.op match + case ast.Add => ("+=", Operators.assignmentPlus) + case ast.Sub => ("-=", Operators.assignmentMinus) + case ast.Mult => ("*=", Operators.assignmentMultiplication) + case ast.MatMult => + ( + "@=", + ".assignmentMatMult" + ) // TODO make this a define and add policy for this + case ast.Div => ("/=", Operators.assignmentDivision) + case ast.Mod => ("%=", Operators.assignmentModulo) + case ast.Pow => ("**=", Operators.assignmentExponentiation) + case ast.LShift => ("<<=", Operators.assignmentShiftLeft) + case ast.RShift => ("<<=", Operators.assignmentArithmeticShiftRight) + case ast.BitOr => ("|=", Operators.assignmentOr) + case ast.BitXor => ("^=", Operators.assignmentXor) + case ast.BitAnd => ("&=", Operators.assignmentAnd) + case ast.FloorDiv => + ( + "//=", + ".assignmentFloorDiv" + ) // TODO make this a define and add policy for this + + createAugAssignment( + targetNode, + operatorCode, + valueNode, + operatorFullName, + lineAndColOf(augAssign) + ) + end convert + + // TODO write test + def convert(forStmt: ast.For): NewNode = + createForLowering( + forStmt.target, + forStmt.iter, + Iterable.empty, + forStmt.body.map(convert), + forStmt.orelse.map(convert), + isAsync = false, + lineAndColOf(forStmt) + ) + + def convert(forStmt: ast.AsyncFor): NewNode = + createForLowering( + forStmt.target, + forStmt.iter, + Iterable.empty, + forStmt.body.map(convert), + forStmt.orelse.map(convert), + isAsync = true, + lineAndColOf(forStmt) + ) + + // Lowering of for x in y: : + // { + // iterator = y.__iter__() + // while (UNKNOWN condition): + // (x = iterator.__next__()) + // + // } + // If one "if" is present the lowering of for x in y if z: + // { + // iterator = y.__iter__() + // while (UNKNOWN condition): + // if (!z): continue + // (x = iterator.__next__()) + // + // } + // If multiple "ifs" are present the lowering of for x in y if z if a: ..,: + // { + // iterator = y.__iter__() + // while (UNKNOWN condition): + // if (!(z and a)): continue + // (x = iterator.__next__()) + // + // } + protected def createForLowering( + target: ast.iexpr, + iter: ast.iexpr, + ifs: Iterable[ast.iexpr], + bodyNodes: Iterable[nodes.NewNode], + orelseNodes: Iterable[nodes.NewNode], + isAsync: Boolean, + lineAndColumn: LineAndColumn + ): nodes.NewNode = + val iterVariableName = getUnusedName() + val iterExprIterCallNode = + createXDotYCall( + () => convert(iter), + "__iter__", + xMayHaveSideEffects = !iter.isInstanceOf[ast.Name], + lineAndColumn, + Nil, + Nil + ) + val iterAssignNode = + createAssignmentToIdentifier(iterVariableName, iterExprIterCallNode, lineAndColumn) - def convert(astIf: ast.If): nodes.NewNode = - val conditionNode = convert(astIf.test) - val bodyStmtNodes = astIf.body.map(convert) + val conditionNode = + nodeBuilder.unknownNode("iteratorNonEmptyOrException", "", lineAndColumn) - val controlStructureNode = - nodeBuilder.controlStructureNode( - "if ... : ...", - ControlStructureTypes.IF, - lineAndColOf(astIf) - ) - edgeBuilder.conditionEdge(conditionNode, controlStructureNode) - - val bodyBlockNode = createBlock(bodyStmtNodes, lineAndColOf(astIf)) - addAstChildNodes(controlStructureNode, 1, conditionNode, bodyBlockNode) - - if astIf.orelse.nonEmpty then - val elseStmtNodes = astIf.orelse.map(convert) - val elseBlockNode = createBlock(elseStmtNodes, lineAndColOf(astIf.orelse.head)) - addAstChildNodes(controlStructureNode, 3, elseBlockNode) - - controlStructureNode - end convert - - def convert(withStmt: ast.With): NewNode = - val loweredNodes = - withStmt.items.foldRight(withStmt.body.map(convert)) { case (withItem, bodyStmts) => - mutable.ArrayBuffer.empty.append(convertWithItem(withItem, bodyStmts)) - } - - loweredNodes.head - - def convert(withStmt: ast.AsyncWith): NewNode = - val loweredNodes = - withStmt.items.foldRight(withStmt.body.map(convert)) { case (withItem, bodyStmts) => - mutable.ArrayBuffer.empty.append(convertWithItem(withItem, bodyStmts)) - } - - loweredNodes.head - - // Handles the lowering of a single "with item". E.g. for: with A() as a, B() as b - // "A() as a" is a single "with item". - // The lowering for: - // with EXPRESSION as TARGET: - // SUITE - // is: - // manager = (EXPRESSION) - // enter = manager.__enter__ - // exit = manager.__exit__ - // value = enter(manager) - // - // try: - // TARGET = value - // SUITE - // finally: - // exit(manager) - // - // Note that this is not quite semantically correct because we ignore the exit method - // arguments and the exit call return value. This is fine for us because for our data - // flow tracking purposed we dont need that extra information and the AST is anyway - // "broken" because we are heavily lowering. - // For reference the following excerpt taken from the official Python 3.9.5 - // documentation found at - // https://docs.python.org/3/reference/compound_stmts.html#the-with-statement - // shows the semantically correct lowering: - // manager = (EXPRESSION) - // enter = type(manager).__enter__ - // exit = type(manager).__exit__ - // value = enter(manager) - // hit_except = False - // - // try: - // TARGET = value - // SUITE - // except: - // hit_except = True - // if not exit(manager, *sys.exc_info()): - // raise - // finally: - // if not hit_except: - // exit(manager, None, None, None) - private def convertWithItem( - withItem: ast.Withitem, - suite: collection.Seq[nodes.NewNode] - ): nodes.NewNode = - val lineAndCol = lineAndColOf(withItem.context_expr) - val managerIdentifierName = getUnusedName("manager") - - val assignmentToManager = - createAssignmentToIdentifier( - managerIdentifierName, - convert(withItem.context_expr), - lineAndCol - ) + val controlStructureNode = + nodeBuilder.controlStructureNode( + "while ... : ...", + ControlStructureTypes.WHILE, + lineAndColumn + ) + edgeBuilder.conditionEdge(conditionNode, controlStructureNode) - val enterIdentifierName = getUnusedName("enter") - val assignmentToEnter = createAssignmentToIdentifier( - enterIdentifierName, - createFieldAccess( - createIdentifierNode(managerIdentifierName, Load, lineAndCol), - "__enter__", - lineAndCol - ), - lineAndCol + val iterNextCallNode = + createXDotYCall( + () => createIdentifierNode(iterVariableName, Load, lineAndColumn), + "__next__", + xMayHaveSideEffects = false, + lineAndColumn, + Nil, + Nil ) - val exitIdentifierName = getUnusedName("exit") - val assignmentToExit = createAssignmentToIdentifier( - exitIdentifierName, - createFieldAccess( - createIdentifierNode(managerIdentifierName, Load, lineAndCol), - "__exit__", - lineAndCol - ), - lineAndCol + val loweredAssignNodes = + createValueToTargetsDecomposition( + Iterable.single(target), + iterNextCallNode, + lineAndColumn ) - val valueIdentifierName = getUnusedName("value") - val assignmentToValue = createAssignmentToIdentifier( - valueIdentifierName, - createInstanceCall( - createIdentifierNode(enterIdentifierName, Load, lineAndCol), - createIdentifierNode(managerIdentifierName, Load, lineAndCol), - "", - lineAndCol, - Nil, - Nil + val blockStmtNodes = mutable.ArrayBuffer.empty[nodes.NewNode] + blockStmtNodes.appendAll(loweredAssignNodes) + + if ifs.nonEmpty then + val conditionNode = + if ifs.size == 1 then + ifs.head + else + ast.BoolOp(ast.And, ifs.to(mutable.Seq), ifs.head.attributeProvider) + val ifNotContinueNode = convert( + ast.If( + ast.UnaryOp(ast.Not, conditionNode, ifs.head.attributeProvider), + mutable.ArrayBuffer.empty[ast.istmt].append( + ast.Continue(ifs.head.attributeProvider) ), - lineAndCol + mutable.Seq.empty[ast.istmt], + ifs.head.attributeProvider ) + ) - val tryBody = - withItem.optional_vars match - case Some(optionalVar) => - val loweredTargetAssignNodes = createValueToTargetsDecomposition( - withItem.optional_vars, - createIdentifierNode(valueIdentifierName, Load, lineAndCol), - lineAndCol - ) - - loweredTargetAssignNodes ++ suite - case None => - suite - - // TODO For the except handler we currently lower as: - // hit_except = True - // exit(manager) - // instead of: - // hit_except = True - // if not exit(manager, *sys.exc_info()): - // raise - - val finalBlockStmts = - createInstanceCall( - createIdentifierNode("__exit__", Load, lineAndCol), - createIdentifierNode(managerIdentifierName, Load, lineAndCol), - "", - lineAndCol, - Nil, - Nil - ) :: Nil - - val tryBlock = createTry(tryBody, Nil, finalBlockStmts, Nil, lineAndCol) - - val blockStmts = mutable.ArrayBuffer.empty[nodes.NewNode] - blockStmts.append(assignmentToManager) - blockStmts.append(assignmentToEnter) - blockStmts.append(assignmentToExit) - blockStmts.append(assignmentToValue) - blockStmts.append(tryBlock) - - createBlock(blockStmts, lineAndCol) - end convertWithItem - - // TODO add case pattern and guard statements to cpg - def convert(matchStmt: ast.Match): NewNode = - val controlStructureNode = - nodeBuilder.controlStructureNode( - "match ... : ...", - ControlStructureTypes.SWITCH, - lineAndColOf(matchStmt) - ) + blockStmtNodes.append(ifNotContinueNode) + bodyNodes.foreach(blockStmtNodes.append) - val matchSubject = convert(matchStmt.subject) - - val caseBlocks = matchStmt.cases.map { caseStmt => - val bodyNodes = caseStmt.body.map(convert) - createBlock(bodyNodes, lineAndColOf(caseStmt.pattern)) - } + val bodyBlockNode = createBlock(blockStmtNodes, lineAndColumn) + addAstChildNodes(controlStructureNode, 1, conditionNode, bodyBlockNode) - edgeBuilder.conditionEdge(matchSubject, controlStructureNode) - addAstChildNodes(controlStructureNode, 1, matchSubject) - addAstChildNodes(controlStructureNode, 2, caseBlocks) + if orelseNodes.nonEmpty then + val elseBlockNode = createBlock(orelseNodes, lineAndColumn) + addAstChildNodes(controlStructureNode, 3, elseBlockNode) - controlStructureNode - end convert + createBlock(iterAssignNode :: controlStructureNode :: Nil, lineAndColumn) + end createForLowering - def convert(raise: ast.Raise): NewNode = - val excNodeOption = raise.exc.map(convert) - val causeNodeOption = raise.cause.map(convert) + def convert(astWhile: ast.While): nodes.NewNode = + val conditionNode = convert(astWhile.test) + val bodyStmtNodes = astWhile.body.map(convert) - val args = mutable.ArrayBuffer.empty[nodes.NewNode] - args.appendAll(excNodeOption) - args.appendAll(causeNodeOption) - - val code = "raise" + - excNodeOption.map(excNode => " " + codeOf(excNode)).getOrElse("") + - causeNodeOption.map(causeNode => " from " + codeOf(causeNode)).getOrElse("") - - val callNode = nodeBuilder.callNode( - code, - ".raise", - DispatchTypes.STATIC_DISPATCH, - lineAndColOf(raise) + val controlStructureNode = + nodeBuilder.controlStructureNode( + "while ... : ...", + ControlStructureTypes.WHILE, + lineAndColOf(astWhile) ) + edgeBuilder.conditionEdge(conditionNode, controlStructureNode) - addAstChildrenAsArguments(callNode, 1, args) + val bodyBlockNode = createBlock(bodyStmtNodes, lineAndColOf(astWhile)) + addAstChildNodes(controlStructureNode, 1, conditionNode, bodyBlockNode) - callNode - end convert + if astWhile.orelse.nonEmpty then + val elseStmtNodes = astWhile.orelse.map(convert) + val elseBlockNode = + createBlock(elseStmtNodes, lineAndColOf(astWhile.orelse.head)) + addAstChildNodes(controlStructureNode, 3, elseBlockNode) - def convert(tryStmt: ast.Try): NewNode = - createTry( - tryStmt.body.map(convert), - tryStmt.handlers.map(convert), - tryStmt.finalbody.map(convert), - tryStmt.orelse.map(convert), - lineAndColOf(tryStmt) - ) + controlStructureNode + end convert - def convert(assert: ast.Assert): NewNode = - val testNode = convert(assert.test) - val msgNode = assert.msg.map(convert) + def convert(astIf: ast.If): nodes.NewNode = + val conditionNode = convert(astIf.test) + val bodyStmtNodes = astIf.body.map(convert) - val code = "assert " + codeOf(testNode) + msgNode.map(m => ", " + codeOf(m)).getOrElse("") - val callNode = nodeBuilder.callNode( - code, - ".assert", - DispatchTypes.STATIC_DISPATCH, - lineAndColOf(assert) + val controlStructureNode = + nodeBuilder.controlStructureNode( + "if ... : ...", + ControlStructureTypes.IF, + lineAndColOf(astIf) ) + edgeBuilder.conditionEdge(conditionNode, controlStructureNode) - addAstChildrenAsArguments(callNode, 1, testNode) - if msgNode.isDefined then - addAstChildrenAsArguments(callNode, 2, msgNode) - callNode - - // Lowering of import x: - // x = import("", "x") - // Lowering of import x as y: - // y = import("", "x") - // Lowering of import x, y: - // { - // x = import("", "x") - // y = import("", "y") - // } - def convert(importStmt: ast.Import): NewNode = - createTransformedImport("", importStmt.names, lineAndColOf(importStmt)) - - // Lowering of from x import y: - // y = import("x", "y") - // Lowering of from x import y as z: - // z = import("x", "y") - // Lowering of from x import y, z: - // { - // y = import("x", "y") - // z = import("x", "z") - // } - def convert(importFrom: ast.ImportFrom): NewNode = - var moduleName = "" - - for i <- 0 until importFrom.level do - moduleName = moduleName.appended('.') - moduleName += importFrom.module.getOrElse("") - - createTransformedImport(moduleName, importFrom.names, lineAndColOf(importFrom)) - - def convert(global: ast.Global): NewNode = - global.names.foreach(contextStack.addGlobalVariable) - val code = global.names.mkString("global ", ", ", "") - nodeBuilder.unknownNode(code, global.getClass.getName, lineAndColOf(global)) - - def convert(nonLocal: ast.Nonlocal): NewNode = - nonLocal.names.foreach(contextStack.addNonLocalVariable) - val code = nonLocal.names.mkString("nonlocal ", ", ", "") - nodeBuilder.unknownNode(code, nonLocal.getClass.getName, lineAndColOf(nonLocal)) - - def convert(expr: ast.Expr): nodes.NewNode = - convert(expr.value) - - def convert(pass: ast.Pass): nodes.NewNode = - nodeBuilder.callNode( - "pass", - ".pass", - DispatchTypes.STATIC_DISPATCH, - lineAndColOf(pass) - ) + val bodyBlockNode = createBlock(bodyStmtNodes, lineAndColOf(astIf)) + addAstChildNodes(controlStructureNode, 1, conditionNode, bodyBlockNode) - def convert(astBreak: ast.Break): nodes.NewNode = - nodeBuilder.controlStructureNode( - "break", - ControlStructureTypes.BREAK, - lineAndColOf(astBreak) - ) + if astIf.orelse.nonEmpty then + val elseStmtNodes = astIf.orelse.map(convert) + val elseBlockNode = createBlock(elseStmtNodes, lineAndColOf(astIf.orelse.head)) + addAstChildNodes(controlStructureNode, 3, elseBlockNode) - def convert(astContinue: ast.Continue): nodes.NewNode = - nodeBuilder.controlStructureNode( - "continue", - ControlStructureTypes.CONTINUE, - lineAndColOf(astContinue) - ) + controlStructureNode + end convert - def convert(raise: ast.RaiseP2): NewNode = ??? + def convert(withStmt: ast.With): NewNode = + val loweredNodes = + withStmt.items.foldRight(withStmt.body.map(convert)) { case (withItem, bodyStmts) => + mutable.ArrayBuffer.empty.append(convertWithItem(withItem, bodyStmts)) + } - def convert(errorStatement: ast.ErrorStatement): NewNode = - nodeBuilder.unknownNode( - errorStatement.toString, - errorStatement.getClass.getName, - lineAndColOf(errorStatement) - ) + loweredNodes.head - def convert(expr: ast.iexpr): NewNode = - expr match - case node: ast.BoolOp => convert(node) - case node: ast.NamedExpr => convert(node) - case node: ast.BinOp => convert(node) - case node: ast.UnaryOp => convert(node) - case node: ast.Lambda => convert(node) - case node: ast.IfExp => convert(node) - case node: ast.Dict => convert(node) - case node: ast.Set => convert(node) - case node: ast.ListComp => convert(node) - case node: ast.SetComp => convert(node) - case node: ast.DictComp => convert(node) - case node: ast.GeneratorExp => convert(node) - case node: ast.Await => convert(node) - case node: ast.Yield => unhandled(node) - case node: ast.YieldFrom => unhandled(node) - case node: ast.Compare => convert(node) - case node: ast.Call => convert(node) - case node: ast.FormattedValue => convert(node) - case node: ast.JoinedString => convert(node) - case node: ast.Constant => convert(node) - case node: ast.Attribute => convert(node) - case node: ast.Subscript => convert(node) - case node: ast.Starred => convert(node) - case node: ast.Name => convert(node) - case node: ast.List => convert(node) - case node: ast.Tuple => convert(node) - case node: ast.Slice => unhandled(node) - case node: ast.StringExpList => convert(node) - - def convert(boolOp: ast.BoolOp): nodes.NewNode = - def boolOpToCodeAndFullName(operator: ast.iboolop): () => (String, String) = () => - operator match - case ast.And => ("and", Operators.logicalAnd) - case ast.Or => ("or", Operators.logicalOr) - - val operandNodes = boolOp.values.map(convert) - createNAryOperatorCall( - boolOpToCodeAndFullName(boolOp.op), - operandNodes, - lineAndColOf(boolOp) - ) + def convert(withStmt: ast.AsyncWith): NewNode = + val loweredNodes = + withStmt.items.foldRight(withStmt.body.map(convert)) { case (withItem, bodyStmts) => + mutable.ArrayBuffer.empty.append(convertWithItem(withItem, bodyStmts)) + } - // TODO test - def convert(namedExpr: ast.NamedExpr): NewNode = - val targetNode = convert(namedExpr.target) - val valueNode = convert(namedExpr.value) - - createAssignment(targetNode, valueNode, lineAndColOf(namedExpr)) - - def convert(binOp: ast.BinOp): nodes.NewNode = - val lhsNode = convert(binOp.left) - val rhsNode = convert(binOp.right) - - val opCodeAndFullName = - binOp.op match - case ast.Add => ("+", Operators.addition) - case ast.Sub => ("-", Operators.subtraction) - case ast.Mult => ("*", Operators.multiplication) - case ast.MatMult => - ("@", ".matMult") // TODO make this a define and add policy for this - case ast.Div => ("/", Operators.division) - case ast.Mod => ("%", Operators.modulo) - case ast.Pow => ("**", Operators.exponentiation) - case ast.LShift => ("<<", Operators.shiftLeft) - case ast.RShift => (">>", Operators.arithmeticShiftRight) - case ast.BitOr => ("|", Operators.or) - case ast.BitXor => ("^", Operators.xor) - case ast.BitAnd => ("&", Operators.and) - case ast.FloorDiv => - ("//", ".floorDiv") // TODO make this a define and add policy for this - - createBinaryOperatorCall(lhsNode, () => opCodeAndFullName, rhsNode, lineAndColOf(binOp)) - end convert - - def convert(unaryOp: ast.UnaryOp): nodes.NewNode = - val operandNode = convert(unaryOp.operand) - - val (operatorCode, methodFullName) = - unaryOp.op match - case ast.Invert => ("~", Operators.not) - case ast.Not => ("not ", Operators.logicalNot) - case ast.UAdd => ("+", Operators.plus) - case ast.USub => ("-", Operators.minus) - - val code = operatorCode + codeOf(operandNode) - val callNode = nodeBuilder.callNode( - code, - methodFullName, - DispatchTypes.STATIC_DISPATCH, - lineAndColOf(unaryOp) + loweredNodes.head + + // Handles the lowering of a single "with item". E.g. for: with A() as a, B() as b + // "A() as a" is a single "with item". + // The lowering for: + // with EXPRESSION as TARGET: + // SUITE + // is: + // manager = (EXPRESSION) + // enter = manager.__enter__ + // exit = manager.__exit__ + // value = enter(manager) + // + // try: + // TARGET = value + // SUITE + // finally: + // exit(manager) + // + // Note that this is not quite semantically correct because we ignore the exit method + // arguments and the exit call return value. This is fine for us because for our data + // flow tracking purposed we dont need that extra information and the AST is anyway + // "broken" because we are heavily lowering. + // For reference the following excerpt taken from the official Python 3.9.5 + // documentation found at + // https://docs.python.org/3/reference/compound_stmts.html#the-with-statement + // shows the semantically correct lowering: + // manager = (EXPRESSION) + // enter = type(manager).__enter__ + // exit = type(manager).__exit__ + // value = enter(manager) + // hit_except = False + // + // try: + // TARGET = value + // SUITE + // except: + // hit_except = True + // if not exit(manager, *sys.exc_info()): + // raise + // finally: + // if not hit_except: + // exit(manager, None, None, None) + private def convertWithItem( + withItem: ast.Withitem, + suite: collection.Seq[nodes.NewNode] + ): nodes.NewNode = + val lineAndCol = lineAndColOf(withItem.context_expr) + val managerIdentifierName = getUnusedName("manager") + + val assignmentToManager = + createAssignmentToIdentifier( + managerIdentifierName, + convert(withItem.context_expr), + lineAndCol ) - addAstChildrenAsArguments(callNode, 1, operandNode) - - callNode - end convert - - def convert(lambda: ast.Lambda): NewNode = - // TODO test lambda expression. - val lambdaCounter = contextStack.getAndIncLambdaCounter() - val lambdaNumberSuffix = - if lambdaCounter == 0 then - "" - else - lambdaCounter.toString - - val name = "" + lambdaNumberSuffix - val (_, methodRefNode) = createMethodAndMethodRef( - name, - Some(name), - createParameterProcessingFunction(lambda.args, isStatic = false), - () => Iterable.single(convert(new ast.Return(lambda.body, lambda.attributeProvider))), - returns = None, - isAsync = false, - lineAndColOf(lambda) - ) - methodRefNode - end convert + val enterIdentifierName = getUnusedName("enter") + val assignmentToEnter = createAssignmentToIdentifier( + enterIdentifierName, + createFieldAccess( + createIdentifierNode(managerIdentifierName, Load, lineAndCol), + "__enter__", + lineAndCol + ), + lineAndCol + ) - // TODO test - def convert(ifExp: ast.IfExp): NewNode = - val bodyNode = convert(ifExp.body) - val testNode = convert(ifExp.test) - val orElseNode = convert(ifExp.orelse) + val exitIdentifierName = getUnusedName("exit") + val assignmentToExit = createAssignmentToIdentifier( + exitIdentifierName, + createFieldAccess( + createIdentifierNode(managerIdentifierName, Load, lineAndCol), + "__exit__", + lineAndCol + ), + lineAndCol + ) - val code = codeOf(bodyNode) + " if " + codeOf(testNode) + " else " + codeOf(orElseNode) - val callNode = nodeBuilder.callNode( - code, - Operators.conditional, - DispatchTypes.STATIC_DISPATCH, - lineAndColOf(ifExp) - ) + val valueIdentifierName = getUnusedName("value") + val assignmentToValue = createAssignmentToIdentifier( + valueIdentifierName, + createInstanceCall( + createIdentifierNode(enterIdentifierName, Load, lineAndCol), + createIdentifierNode(managerIdentifierName, Load, lineAndCol), + "", + lineAndCol, + Nil, + Nil + ), + lineAndCol + ) - // testNode is first argument to match semantics of Operators.conditional. - addAstChildrenAsArguments(callNode, 1, testNode, bodyNode, orElseNode) - - callNode - - /** Lowering of {x:1, y:2, **z}: { tmp = {} tmp[x] = 1 tmp[y] = 2 tmp.update(z) tmp } - */ - // TODO test - def convert(dict: ast.Dict): NewNode = - val MAX_KV_PAIRS = 100 - val tmpVariableName = getUnusedName() - val dictOperatorCall = - createLiteralOperatorCall("{", "}", ".dictLiteral", lineAndColOf(dict)) - val dictVariableAssigNode = - createAssignmentToIdentifier(tmpVariableName, dictOperatorCall, lineAndColOf(dict)) - - val dictElementAssignNodes = if dict.keys.size > MAX_KV_PAIRS then - Seq( - nodeBuilder - .callNode( - "", - Constants.ANY, - DispatchTypes.STATIC_DISPATCH, - lineAndColOf(dict) - ) - ) - else - dict.keys.zip(dict.values).map { case (key, value) => - key match - case Some(key) => - val indexAccessNode = createIndexAccess( - createIdentifierNode(tmpVariableName, Load, lineAndColOf(dict)), - convert(key), - lineAndColOf(dict) - ) - - createAssignment(indexAccessNode, convert(value), lineAndColOf(dict)) - case None => - createXDotYCall( - () => createIdentifierNode(tmpVariableName, Load, lineAndColOf(dict)), - "update", - xMayHaveSideEffects = false, - lineAndColOf(dict), - convert(value) :: Nil, - Nil - ) - } - - val dictInstanceReturnIdentifierNode = - createIdentifierNode(tmpVariableName, Load, lineAndColOf(dict)) - - val blockElements = mutable.ArrayBuffer.empty[nodes.NewNode] - blockElements.append(dictVariableAssigNode) - blockElements.appendAll(dictElementAssignNodes) - blockElements.append(dictInstanceReturnIdentifierNode) - createBlock(blockElements, lineAndColOf(dict)) - end convert - - // TODO test - def convert(set: ast.Set): nodes.NewNode = - val setElementNodes = set.elts.map(convert) - val code = setElementNodes.map(codeOf).mkString("{", ", ", "}") - - val callNode = nodeBuilder.callNode( - code, - ".setLiteral", - DispatchTypes.STATIC_DISPATCH, - lineAndColOf(set) - ) + val tryBody = + withItem.optional_vars match + case Some(optionalVar) => + val loweredTargetAssignNodes = createValueToTargetsDecomposition( + withItem.optional_vars, + createIdentifierNode(valueIdentifierName, Load, lineAndCol), + lineAndCol + ) - addAstChildrenAsArguments(callNode, 1, setElementNodes) + loweredTargetAssignNodes ++ suite + case None => + suite + + // TODO For the except handler we currently lower as: + // hit_except = True + // exit(manager) + // instead of: + // hit_except = True + // if not exit(manager, *sys.exc_info()): + // raise + + val finalBlockStmts = + createInstanceCall( + createIdentifierNode("__exit__", Load, lineAndCol), + createIdentifierNode(managerIdentifierName, Load, lineAndCol), + "", + lineAndCol, + Nil, + Nil + ) :: Nil - callNode + val tryBlock = createTry(tryBody, Nil, finalBlockStmts, Nil, lineAndCol) - /** Lowering of [x for y in l for x in y]: { tmp = [] ( for y in l: for x in y: - * tmp.append(x) ) tmp } - */ - // TODO test - def convert(listComp: ast.ListComp): NewNode = - contextStack.pushSpecialContext() - val tmpVariableName = getUnusedName() + val blockStmts = mutable.ArrayBuffer.empty[nodes.NewNode] + blockStmts.append(assignmentToManager) + blockStmts.append(assignmentToEnter) + blockStmts.append(assignmentToExit) + blockStmts.append(assignmentToValue) + blockStmts.append(tryBlock) - // Create tmp = list() - val listOperatorCall = - createLiteralOperatorCall("[", "]", ".listLiteral", lineAndColOf(listComp)) - val variableAssignNode = - createAssignmentToIdentifier(tmpVariableName, listOperatorCall, lineAndColOf(listComp)) + createBlock(blockStmts, lineAndCol) + end convertWithItem - // Create tmp.append(x) - val listVarAppendCallNode = createXDotYCall( - () => createIdentifierNode(tmpVariableName, Load, lineAndColOf(listComp)), - "append", - xMayHaveSideEffects = false, - lineAndColOf(listComp), - convert(listComp.elt) :: Nil, - Nil + // TODO add case pattern and guard statements to cpg + def convert(matchStmt: ast.Match): NewNode = + val controlStructureNode = + nodeBuilder.controlStructureNode( + "match ... : ...", + ControlStructureTypes.SWITCH, + lineAndColOf(matchStmt) ) - val comprehensionBlockNode = createComprehensionLowering( - tmpVariableName, - variableAssignNode, - listVarAppendCallNode, - listComp.generators, - lineAndColOf(listComp) - ) + val matchSubject = convert(matchStmt.subject) - contextStack.pop() + val caseBlocks = matchStmt.cases.map { caseStmt => + val bodyNodes = caseStmt.body.map(convert) + createBlock(bodyNodes, lineAndColOf(caseStmt.pattern)) + } - comprehensionBlockNode - end convert + edgeBuilder.conditionEdge(matchSubject, controlStructureNode) + addAstChildNodes(controlStructureNode, 1, matchSubject) + addAstChildNodes(controlStructureNode, 2, caseBlocks) - /** Lowering of {x for y in l for x in y}: { tmp = {} ( for y in l: for x in y: - * tmp.add(x) ) tmp } - */ - // TODO test - def convert(setComp: ast.SetComp): NewNode = - contextStack.pushSpecialContext() - val tmpVariableName = getUnusedName() + controlStructureNode + end convert - val setOperatorCall = - createLiteralOperatorCall("{", "}", ".setLiteral", lineAndColOf(setComp)) - val variableAssignNode = - createAssignmentToIdentifier(tmpVariableName, setOperatorCall, lineAndColOf(setComp)) + def convert(raise: ast.Raise): NewNode = + val excNodeOption = raise.exc.map(convert) + val causeNodeOption = raise.cause.map(convert) - // Create tmp.add(x) - val setVarAddCallNode = createXDotYCall( - () => createIdentifierNode(tmpVariableName, Load, lineAndColOf(setComp)), - "add", - xMayHaveSideEffects = false, - lineAndColOf(setComp), - convert(setComp.elt) :: Nil, - Nil - ) + val args = mutable.ArrayBuffer.empty[nodes.NewNode] + args.appendAll(excNodeOption) + args.appendAll(causeNodeOption) - val comprehensionBlockNode = createComprehensionLowering( - tmpVariableName, - variableAssignNode, - setVarAddCallNode, - setComp.generators, - lineAndColOf(setComp) - ) + val code = "raise" + + excNodeOption.map(excNode => " " + codeOf(excNode)).getOrElse("") + + causeNodeOption.map(causeNode => " from " + codeOf(causeNode)).getOrElse("") - contextStack.pop() - - comprehensionBlockNode - end convert - - /** Lowering of {k:v for y in l for k, v in y}: { tmp = {} ( for y in l: for k, v in - * y: tmp[k] = v ) tmp } - */ - // TODO test - def convert(dictComp: ast.DictComp): NewNode = - contextStack.pushSpecialContext() - val tmpVariableName = getUnusedName() - - val dictOperatorCall = - createLiteralOperatorCall("{", "}", ".dictLiteral", lineAndColOf(dictComp)) - val variableAssignNode = - createAssignmentToIdentifier(tmpVariableName, dictOperatorCall, lineAndColOf(dictComp)) - - // Create tmp[k] = v - val dictAssigNode = createAssignment( - createIndexAccess( - createIdentifierNode(tmpVariableName, Load, lineAndColOf(dictComp)), - convert(dictComp.key), - lineAndColOf(dictComp) - ), - convert(dictComp.value), - lineAndColOf(dictComp) - ) + val callNode = nodeBuilder.callNode( + code, + ".raise", + DispatchTypes.STATIC_DISPATCH, + lineAndColOf(raise) + ) - val comprehensionBlockNode = createComprehensionLowering( - tmpVariableName, - variableAssignNode, - dictAssigNode, - dictComp.generators, - lineAndColOf(dictComp) - ) + addAstChildrenAsArguments(callNode, 1, args) + + callNode + end convert + + def convert(tryStmt: ast.Try): NewNode = + createTry( + tryStmt.body.map(convert), + tryStmt.handlers.map(convert), + tryStmt.finalbody.map(convert), + tryStmt.orelse.map(convert), + lineAndColOf(tryStmt) + ) + + def convert(assert: ast.Assert): NewNode = + val testNode = convert(assert.test) + val msgNode = assert.msg.map(convert) + + val code = "assert " + codeOf(testNode) + msgNode.map(m => ", " + codeOf(m)).getOrElse("") + val callNode = nodeBuilder.callNode( + code, + ".assert", + DispatchTypes.STATIC_DISPATCH, + lineAndColOf(assert) + ) - contextStack.pop() - - comprehensionBlockNode - end convert - - /** Lowering of (x for y in l for x in y): { tmp = .genExp ( for y in l: - * for x in y: tmp.append(x) ) tmp } This lowering is not quite correct as it ignores the lazy - * evaluation of the generator expression. Instead it just mimics the list comprehension - * lowering but for now this is good enough. - */ - // TODO test - def convert(generatorExp: ast.GeneratorExp): NewNode = - contextStack.pushSpecialContext() - val tmpVariableName = getUnusedName() - - // Create tmp = list() - val genExpOperatorCall = - nodeBuilder.callNode( - ".genExp", - ".genExp", - DispatchTypes.STATIC_DISPATCH, - lineAndColOf(generatorExp) - ) + addAstChildrenAsArguments(callNode, 1, testNode) + if msgNode.isDefined then + addAstChildrenAsArguments(callNode, 2, msgNode) + callNode + + // Lowering of import x: + // x = import("", "x") + // Lowering of import x as y: + // y = import("", "x") + // Lowering of import x, y: + // { + // x = import("", "x") + // y = import("", "y") + // } + def convert(importStmt: ast.Import): NewNode = + createTransformedImport("", importStmt.names, lineAndColOf(importStmt)) + + // Lowering of from x import y: + // y = import("x", "y") + // Lowering of from x import y as z: + // z = import("x", "y") + // Lowering of from x import y, z: + // { + // y = import("x", "y") + // z = import("x", "z") + // } + def convert(importFrom: ast.ImportFrom): NewNode = + var moduleName = "" + + for i <- 0 until importFrom.level do + moduleName = moduleName.appended('.') + moduleName += importFrom.module.getOrElse("") + + createTransformedImport(moduleName, importFrom.names, lineAndColOf(importFrom)) + + def convert(global: ast.Global): NewNode = + global.names.foreach(contextStack.addGlobalVariable) + val code = global.names.mkString("global ", ", ", "") + nodeBuilder.unknownNode(code, global.getClass.getName, lineAndColOf(global)) + + def convert(nonLocal: ast.Nonlocal): NewNode = + nonLocal.names.foreach(contextStack.addNonLocalVariable) + val code = nonLocal.names.mkString("nonlocal ", ", ", "") + nodeBuilder.unknownNode(code, nonLocal.getClass.getName, lineAndColOf(nonLocal)) + + def convert(expr: ast.Expr): nodes.NewNode = + convert(expr.value) + + def convert(pass: ast.Pass): nodes.NewNode = + nodeBuilder.callNode( + "pass", + ".pass", + DispatchTypes.STATIC_DISPATCH, + lineAndColOf(pass) + ) + + def convert(astBreak: ast.Break): nodes.NewNode = + nodeBuilder.controlStructureNode( + "break", + ControlStructureTypes.BREAK, + lineAndColOf(astBreak) + ) + + def convert(astContinue: ast.Continue): nodes.NewNode = + nodeBuilder.controlStructureNode( + "continue", + ControlStructureTypes.CONTINUE, + lineAndColOf(astContinue) + ) + + def convert(raise: ast.RaiseP2): NewNode = ??? + + def convert(errorStatement: ast.ErrorStatement): NewNode = + nodeBuilder.unknownNode( + errorStatement.toString, + errorStatement.getClass.getName, + lineAndColOf(errorStatement) + ) + + def convert(expr: ast.iexpr): NewNode = + expr match + case node: ast.BoolOp => convert(node) + case node: ast.NamedExpr => convert(node) + case node: ast.BinOp => convert(node) + case node: ast.UnaryOp => convert(node) + case node: ast.Lambda => convert(node) + case node: ast.IfExp => convert(node) + case node: ast.Dict => convert(node) + case node: ast.Set => convert(node) + case node: ast.ListComp => convert(node) + case node: ast.SetComp => convert(node) + case node: ast.DictComp => convert(node) + case node: ast.GeneratorExp => convert(node) + case node: ast.Await => convert(node) + case node: ast.Yield => unhandled(node) + case node: ast.YieldFrom => unhandled(node) + case node: ast.Compare => convert(node) + case node: ast.Call => convert(node) + case node: ast.FormattedValue => convert(node) + case node: ast.JoinedString => convert(node) + case node: ast.Constant => convert(node) + case node: ast.Attribute => convert(node) + case node: ast.Subscript => convert(node) + case node: ast.Starred => convert(node) + case node: ast.Name => convert(node) + case node: ast.List => convert(node) + case node: ast.Tuple => convert(node) + case node: ast.Slice => unhandled(node) + case node: ast.StringExpList => convert(node) + + def convert(boolOp: ast.BoolOp): nodes.NewNode = + def boolOpToCodeAndFullName(operator: ast.iboolop): () => (String, String) = () => + operator match + case ast.And => ("and", Operators.logicalAnd) + case ast.Or => ("or", Operators.logicalOr) + + val operandNodes = boolOp.values.map(convert) + createNAryOperatorCall( + boolOpToCodeAndFullName(boolOp.op), + operandNodes, + lineAndColOf(boolOp) + ) - val variableAssignNode = - createAssignmentToIdentifier( - tmpVariableName, - genExpOperatorCall, - lineAndColOf(generatorExp) - ) + // TODO test + def convert(namedExpr: ast.NamedExpr): NewNode = + val targetNode = convert(namedExpr.target) + val valueNode = convert(namedExpr.value) + + createAssignment(targetNode, valueNode, lineAndColOf(namedExpr)) + + def convert(binOp: ast.BinOp): nodes.NewNode = + val lhsNode = convert(binOp.left) + val rhsNode = convert(binOp.right) + + val opCodeAndFullName = + binOp.op match + case ast.Add => ("+", Operators.addition) + case ast.Sub => ("-", Operators.subtraction) + case ast.Mult => ("*", Operators.multiplication) + case ast.MatMult => + ("@", ".matMult") // TODO make this a define and add policy for this + case ast.Div => ("/", Operators.division) + case ast.Mod => ("%", Operators.modulo) + case ast.Pow => ("**", Operators.exponentiation) + case ast.LShift => ("<<", Operators.shiftLeft) + case ast.RShift => (">>", Operators.arithmeticShiftRight) + case ast.BitOr => ("|", Operators.or) + case ast.BitXor => ("^", Operators.xor) + case ast.BitAnd => ("&", Operators.and) + case ast.FloorDiv => + ("//", ".floorDiv") // TODO make this a define and add policy for this + + createBinaryOperatorCall(lhsNode, () => opCodeAndFullName, rhsNode, lineAndColOf(binOp)) + end convert + + def convert(unaryOp: ast.UnaryOp): nodes.NewNode = + val operandNode = convert(unaryOp.operand) + + val (operatorCode, methodFullName) = + unaryOp.op match + case ast.Invert => ("~", Operators.not) + case ast.Not => ("not ", Operators.logicalNot) + case ast.UAdd => ("+", Operators.plus) + case ast.USub => ("-", Operators.minus) + + val code = operatorCode + codeOf(operandNode) + val callNode = nodeBuilder.callNode( + code, + methodFullName, + DispatchTypes.STATIC_DISPATCH, + lineAndColOf(unaryOp) + ) - // Create tmp.append(x) - val genExpAppendCallNode = createXDotYCall( - () => createIdentifierNode(tmpVariableName, Load, lineAndColOf(generatorExp)), - "append", - xMayHaveSideEffects = false, - lineAndColOf(generatorExp), - convert(generatorExp.elt) :: Nil, - Nil - ) + addAstChildrenAsArguments(callNode, 1, operandNode) - val comprehensionBlockNode = createComprehensionLowering( - tmpVariableName, - variableAssignNode, - genExpAppendCallNode, - generatorExp.generators, - lineAndColOf(generatorExp) - ) + callNode + end convert - contextStack.pop() - - comprehensionBlockNode - end convert - - def convert(await: ast.Await): NewNode = - // Since the CPG format does not provide means to model async/await, - // we for now treat it as non existing. - convert(await.value) - - def convert(yieldExpr: ast.Yield): NewNode = ??? - - def convert(yieldFrom: ast.YieldFrom): NewNode = ??? - - // In case of a single compare operation there is no lowering applied. - // So e.g. x < y stay untouched. - // Otherwise the lowering is as follows: - // Src AST: - // x < y < z < a - // Lowering: - // { - // tmp1 = y - // x < tmp1 && { - // tmp2 = z - // tmp1 < tmp2 && { - // tmp2 < a - // } - // } - // } - def convert(compare: ast.Compare): NewNode = - assert(compare.ops.size == compare.comparators.size) - var lhsNode = convert(compare.left) - - val topLevelExprNodes = - lowerComparatorChain(lhsNode, compare.ops, compare.comparators, lineAndColOf(compare)) - if topLevelExprNodes.size > 1 then - createBlock(topLevelExprNodes, lineAndColOf(compare)) - else - topLevelExprNodes.head - - private def compopToOpCodeAndFullName(compareOp: ast.icompop): () => (String, String) = () => - compareOp match - case ast.Eq => ("==", Operators.equals) - case ast.NotEq => ("!=", Operators.notEquals) - case ast.Lt => ("<", Operators.lessThan) - case ast.LtE => ("<=", Operators.lessEqualsThan) - case ast.Gt => (">", Operators.greaterThan) - case ast.GtE => (">=", Operators.greaterEqualsThan) - case ast.Is => ("is", ".is") - case ast.IsNot => ("is not", ".isNot") - case ast.In => ("in", ".in") - case ast.NotIn => ("not in", ".notIn") - - def lowerComparatorChain( - lhsNode: nodes.NewNode, - compOperators: Iterable[ast.icompop], - comparators: Iterable[ast.iexpr], - lineAndColumn: LineAndColumn - ): Iterable[nodes.NewNode] = - val rhsNode = convert(comparators.head) - - if compOperators.size == 1 then - val compareNode = - createBinaryOperatorCall( - lhsNode, - compopToOpCodeAndFullName(compOperators.head), - rhsNode, - lineAndColumn - ) - Iterable.single(compareNode) + def convert(lambda: ast.Lambda): NewNode = + // TODO test lambda expression. + val lambdaCounter = contextStack.getAndIncLambdaCounter() + val lambdaNumberSuffix = + if lambdaCounter == 0 then + "" else - val tmpVariableName = getUnusedName() - val assignmentNode = - createAssignmentToIdentifier(tmpVariableName, rhsNode, lineAndColumn) - - val tmpIdentifierCompare1 = createIdentifierNode(tmpVariableName, Load, lineAndColumn) - val compareNode = createBinaryOperatorCall( - lhsNode, - compopToOpCodeAndFullName(compOperators.head), - tmpIdentifierCompare1, - lineAndColumn - ) - - val tmpIdentifierCompare2 = createIdentifierNode(tmpVariableName, Load, lineAndColumn) - val childNodes = lowerComparatorChain( - tmpIdentifierCompare2, - compOperators.tail, - comparators.tail, - lineAndColumn - ) - - val blockNode = createBlock(childNodes, lineAndColumn) + lambdaCounter.toString + + val name = "" + lambdaNumberSuffix + val (_, methodRefNode) = createMethodAndMethodRef( + name, + Some(name), + createParameterProcessingFunction(lambda.args, isStatic = false), + () => Iterable.single(convert(new ast.Return(lambda.body, lambda.attributeProvider))), + returns = None, + isAsync = false, + lineAndColOf(lambda) + ) + methodRefNode + end convert + + // TODO test + def convert(ifExp: ast.IfExp): NewNode = + val bodyNode = convert(ifExp.body) + val testNode = convert(ifExp.test) + val orElseNode = convert(ifExp.orelse) + + val code = codeOf(bodyNode) + " if " + codeOf(testNode) + " else " + codeOf(orElseNode) + val callNode = nodeBuilder.callNode( + code, + Operators.conditional, + DispatchTypes.STATIC_DISPATCH, + lineAndColOf(ifExp) + ) - Iterable( - assignmentNode, - createBinaryOperatorCall( - compareNode, - andOpCodeAndFullName(), - blockNode, - lineAndColumn - ) + // testNode is first argument to match semantics of Operators.conditional. + addAstChildrenAsArguments(callNode, 1, testNode, bodyNode, orElseNode) + + callNode + + /** Lowering of {x:1, y:2, **z}: { tmp = {} tmp[x] = 1 tmp[y] = 2 tmp.update(z) tmp } + */ + // TODO test + def convert(dict: ast.Dict): NewNode = + val MAX_KV_PAIRS = 100 + val tmpVariableName = getUnusedName() + val dictOperatorCall = + createLiteralOperatorCall("{", "}", ".dictLiteral", lineAndColOf(dict)) + val dictVariableAssigNode = + createAssignmentToIdentifier(tmpVariableName, dictOperatorCall, lineAndColOf(dict)) + + val dictElementAssignNodes = if dict.keys.size > MAX_KV_PAIRS then + Seq( + nodeBuilder + .callNode( + "", + Constants.ANY, + DispatchTypes.STATIC_DISPATCH, + lineAndColOf(dict) ) - end if - end lowerComparatorChain - - private def andOpCodeAndFullName(): () => (String, String) = () => - ("and", Operators.logicalAnd) - - /** TODO For now this function compromises on the correctness of the lowering in order to get - * some data flow tracking going. - * 1. For constructs like x.func() we assume x to be the instance which is passed into func. - * This is not true since the instance method object gets the instance already - * bound/captured during function access. This becomes relevant for constructs like: - * x.func = y.func <- y.func is class method object x.func() In this case the instance - * passed into func is y and not x. We cannot represent this in th CPG and thus stick to - * the assumption that the part before the "." and the bound/captured instance will be the - * same. For reference see: - * https://docs.python.org/3/reference/datamodel.html#the-standard-type-hierarchy search - * for "Instance methods" - */ - def convert(call: ast.Call): nodes.NewNode = - val argumentNodes = call.args.map(convert).toSeq - val keywordArgNodes = call.keywords.flatMap { keyword => - if keyword.arg.isDefined then - Some((keyword.arg.get, convert(keyword.value))) - else - // keyword.arg == None. This is the case for func(**dict) style arguments. - // TODO implement handling for this case. - None - } + ) + else + dict.keys.zip(dict.values).map { case (key, value) => + key match + case Some(key) => + val indexAccessNode = createIndexAccess( + createIdentifierNode(tmpVariableName, Load, lineAndColOf(dict)), + convert(key), + lineAndColOf(dict) + ) - call.func match - case attribute: ast.Attribute => + createAssignment(indexAccessNode, convert(value), lineAndColOf(dict)) + case None => createXDotYCall( - () => convert(attribute.value), - attribute.attr, - xMayHaveSideEffects = !attribute.value.isInstanceOf[ast.Name], - lineAndColOf(call), - argumentNodes, - keywordArgNodes + () => createIdentifierNode(tmpVariableName, Load, lineAndColOf(dict)), + "update", + xMayHaveSideEffects = false, + lineAndColOf(dict), + convert(value) :: Nil, + Nil ) - case _ => - val receiverNode = convert(call.func) - val name = call.func match - case ast.Name(id, _) => id - case _ => "" - createCall(receiverNode, name, lineAndColOf(call), argumentNodes, keywordArgNodes) - end convert - - def convert(formattedValue: ast.FormattedValue): nodes.NewNode = - val valueNode = convert(formattedValue.value) - - val equalSignStr = if formattedValue.equalSign then "=" else "" - val conversionStr = formattedValue.conversion match - case -1 => "" - case 115 => "!s" - case 114 => "!r" - case 97 => "!a" - - val formatSpecStr = formattedValue.format_spec match - case Some(formatSpec) => ":" + formatSpec - case None => "" - - val code = "{" + codeOf(valueNode) + equalSignStr + conversionStr + formatSpecStr + "}" - - val callNode = nodeBuilder.callNode( - code, - ".formattedValue", - DispatchTypes.STATIC_DISPATCH, - lineAndColOf(formattedValue) - ) + } + + val dictInstanceReturnIdentifierNode = + createIdentifierNode(tmpVariableName, Load, lineAndColOf(dict)) + + val blockElements = mutable.ArrayBuffer.empty[nodes.NewNode] + blockElements.append(dictVariableAssigNode) + blockElements.appendAll(dictElementAssignNodes) + blockElements.append(dictInstanceReturnIdentifierNode) + createBlock(blockElements, lineAndColOf(dict)) + end convert + + // TODO test + def convert(set: ast.Set): nodes.NewNode = + val setElementNodes = set.elts.map(convert) + val code = setElementNodes.map(codeOf).mkString("{", ", ", "}") + + val callNode = nodeBuilder.callNode( + code, + ".setLiteral", + DispatchTypes.STATIC_DISPATCH, + lineAndColOf(set) + ) + + addAstChildrenAsArguments(callNode, 1, setElementNodes) + + callNode + + /** Lowering of [x for y in l for x in y]: { tmp = [] ( for y in l: for x in y: + * tmp.append(x) ) tmp } + */ + // TODO test + def convert(listComp: ast.ListComp): NewNode = + contextStack.pushSpecialContext() + val tmpVariableName = getUnusedName() + + // Create tmp = list() + val listOperatorCall = + createLiteralOperatorCall("[", "]", ".listLiteral", lineAndColOf(listComp)) + val variableAssignNode = + createAssignmentToIdentifier(tmpVariableName, listOperatorCall, lineAndColOf(listComp)) + + // Create tmp.append(x) + val listVarAppendCallNode = createXDotYCall( + () => createIdentifierNode(tmpVariableName, Load, lineAndColOf(listComp)), + "append", + xMayHaveSideEffects = false, + lineAndColOf(listComp), + convert(listComp.elt) :: Nil, + Nil + ) + + val comprehensionBlockNode = createComprehensionLowering( + tmpVariableName, + variableAssignNode, + listVarAppendCallNode, + listComp.generators, + lineAndColOf(listComp) + ) - addAstChildrenAsArguments(callNode, 1, valueNode) + contextStack.pop() + + comprehensionBlockNode + end convert + + /** Lowering of {x for y in l for x in y}: { tmp = {} ( for y in l: for x in y: + * tmp.add(x) ) tmp } + */ + // TODO test + def convert(setComp: ast.SetComp): NewNode = + contextStack.pushSpecialContext() + val tmpVariableName = getUnusedName() + + val setOperatorCall = + createLiteralOperatorCall("{", "}", ".setLiteral", lineAndColOf(setComp)) + val variableAssignNode = + createAssignmentToIdentifier(tmpVariableName, setOperatorCall, lineAndColOf(setComp)) + + // Create tmp.add(x) + val setVarAddCallNode = createXDotYCall( + () => createIdentifierNode(tmpVariableName, Load, lineAndColOf(setComp)), + "add", + xMayHaveSideEffects = false, + lineAndColOf(setComp), + convert(setComp.elt) :: Nil, + Nil + ) - callNode - end convert + val comprehensionBlockNode = createComprehensionLowering( + tmpVariableName, + variableAssignNode, + setVarAddCallNode, + setComp.generators, + lineAndColOf(setComp) + ) - def convert(joinedString: ast.JoinedString): nodes.NewNode = - val argumentNodes = joinedString.values.map(convert) + contextStack.pop() + + comprehensionBlockNode + end convert + + /** Lowering of {k:v for y in l for k, v in y}: { tmp = {} ( for y in l: for k, v in + * y: tmp[k] = v ) tmp } + */ + // TODO test + def convert(dictComp: ast.DictComp): NewNode = + contextStack.pushSpecialContext() + val tmpVariableName = getUnusedName() + + val dictOperatorCall = + createLiteralOperatorCall("{", "}", ".dictLiteral", lineAndColOf(dictComp)) + val variableAssignNode = + createAssignmentToIdentifier(tmpVariableName, dictOperatorCall, lineAndColOf(dictComp)) + + // Create tmp[k] = v + val dictAssigNode = createAssignment( + createIndexAccess( + createIdentifierNode(tmpVariableName, Load, lineAndColOf(dictComp)), + convert(dictComp.key), + lineAndColOf(dictComp) + ), + convert(dictComp.value), + lineAndColOf(dictComp) + ) - val code = joinedString.prefix + joinedString.quote + argumentNodes - .map(codeOf) - .mkString("") + joinedString.quote + val comprehensionBlockNode = createComprehensionLowering( + tmpVariableName, + variableAssignNode, + dictAssigNode, + dictComp.generators, + lineAndColOf(dictComp) + ) - val callNode = - nodeBuilder.callNode( - code, - ".formatString", - DispatchTypes.STATIC_DISPATCH, - lineAndColOf(joinedString) - ) + contextStack.pop() - addAstChildrenAsArguments(callNode, 1, argumentNodes) + comprehensionBlockNode + end convert - callNode + /** Lowering of (x for y in l for x in y): { tmp = .genExp ( for y in l: for + * x in y: tmp.append(x) ) tmp } This lowering is not quite correct as it ignores the lazy + * evaluation of the generator expression. Instead it just mimics the list comprehension lowering + * but for now this is good enough. + */ + // TODO test + def convert(generatorExp: ast.GeneratorExp): NewNode = + contextStack.pushSpecialContext() + val tmpVariableName = getUnusedName() - def convert(constant: ast.Constant): nodes.NewNode = - constant.value match - case stringConstant: ast.StringConstant => - nodeBuilder.stringLiteralNode( - stringConstant.prefix + stringConstant.quote + stringConstant.value + stringConstant.quote, - lineAndColOf(constant) - ) - case stringConstant: ast.JoinedStringConstant => - nodeBuilder.stringLiteralNode(stringConstant.value, lineAndColOf(constant)) - case boolConstant: ast.BoolConstant => - val boolStr = if boolConstant.value then "True" else "False" - nodeBuilder.stringLiteralNode(boolStr, lineAndColOf(constant)) - case intConstant: ast.IntConstant => - nodeBuilder.numberLiteralNode(intConstant.value, lineAndColOf(constant)) - case floatConstant: ast.FloatConstant => - nodeBuilder.numberLiteralNode(floatConstant.value, lineAndColOf(constant)) - case imaginaryConstant: ast.ImaginaryConstant => - nodeBuilder.numberLiteralNode(imaginaryConstant.value + "j", lineAndColOf(constant)) - case ast.NoneConstant => - nodeBuilder.numberLiteralNode("None", lineAndColOf(constant)) - case ast.EllipsisConstant => - nodeBuilder.numberLiteralNode("...", lineAndColOf(constant)) - - /** TODO We currently ignore possible attribute access provider/interception mechanisms like - * __getattr__, __getattribute__ and __get__. - */ - def convert(attribute: ast.Attribute): nodes.NewNode = - val baseNode = convert(attribute.value) - val fieldName = attribute.attr - val lineAndCol = lineAndColOf(attribute) - - val fieldAccess = createFieldAccess(baseNode, fieldName, lineAndCol) - - attribute.value match - case name: ast.Name if name.id == "self" => - createAndRegisterMember(fieldName, lineAndCol) - case _ => - - fieldAccess - - private def createAndRegisterMember(name: String, lineAndCol: LineAndColumn): Unit = - contextStack.findEnclosingTypeDecl() match - case Some(typeDecl: NewTypeDecl) => - if !members.contains(typeDecl) || !members(typeDecl).contains(name) then - val member = nodeBuilder.memberNode(name, lineAndCol) - edgeBuilder.astEdge(member, typeDecl, contextStack.order.getAndInc) - members(typeDecl) = members.getOrElse(typeDecl, List()) ++ List(name) - case _ => - - def convert(subscript: ast.Subscript): NewNode = - createIndexAccess( - convert(subscript.value), - convert(subscript.slice), - lineAndColOf(subscript) + // Create tmp = list() + val genExpOperatorCall = + nodeBuilder.callNode( + ".genExp", + ".genExp", + DispatchTypes.STATIC_DISPATCH, + lineAndColOf(generatorExp) ) - def convert(starred: ast.Starred): NewNode = - val memoryOperation = memOpMap.get(starred).get - - memoryOperation match - case Load => - val unrollOperand = convert(starred.value) - createStarredUnpackOperatorCall(unrollOperand, lineAndColOf(starred)) - case Store => - unhandled(starred) - case Del => - // This case is not possible since star operator is not allowed in delete statement. - unhandled(starred) - - def convert(name: ast.Name): nodes.NewNode = - val memoryOperation = memOpMap.get(name).get - val identifier = createIdentifierNode(name.id, memoryOperation, lineAndColOf(name)) - if contextStack.isClassContext && memoryOperation == Store then - createAndRegisterMember(identifier.name, lineAndColOf(name)) - identifier - - // TODO test - def convert(list: ast.List): nodes.NewNode = - // Must be a List as part of a Load memory operation because a List literal - // is not permitted as argument to a Del and List as part of a Store does not - // reach here. - assert(memOpMap.get(list).get == Load) - val listElementNodes = list.elts.map(convert) - val code = listElementNodes.map(codeOf).mkString("[", ", ", "]") - - val callNode = - nodeBuilder.callNode( - code, - ".listLiteral", - DispatchTypes.STATIC_DISPATCH, - lineAndColOf(list) - ) + val variableAssignNode = + createAssignmentToIdentifier( + tmpVariableName, + genExpOperatorCall, + lineAndColOf(generatorExp) + ) - addAstChildrenAsArguments(callNode, 1, listElementNodes) + // Create tmp.append(x) + val genExpAppendCallNode = createXDotYCall( + () => createIdentifierNode(tmpVariableName, Load, lineAndColOf(generatorExp)), + "append", + xMayHaveSideEffects = false, + lineAndColOf(generatorExp), + convert(generatorExp.elt) :: Nil, + Nil + ) - callNode + val comprehensionBlockNode = createComprehensionLowering( + tmpVariableName, + variableAssignNode, + genExpAppendCallNode, + generatorExp.generators, + lineAndColOf(generatorExp) + ) - // TODO test - def convert(tuple: ast.Tuple): NewNode = - // Must be a tuple as part of a Load or Del memory operation because Tuples in - // store contexts are not supposed to reach here. They need to be lowered by - // createValueToTargetsDecomposition. - assert(memOpMap.get(tuple).get == Load || memOpMap.get(tuple).get == Del) - val tupleElementNodes = tuple.elts.map(convert) - val code = if tupleElementNodes.size != 1 then - tupleElementNodes.map(codeOf).mkString("(", ", ", ")") + contextStack.pop() + + comprehensionBlockNode + end convert + + def convert(await: ast.Await): NewNode = + // Since the CPG format does not provide means to model async/await, + // we for now treat it as non existing. + convert(await.value) + + def convert(yieldExpr: ast.Yield): NewNode = ??? + + def convert(yieldFrom: ast.YieldFrom): NewNode = ??? + + // In case of a single compare operation there is no lowering applied. + // So e.g. x < y stay untouched. + // Otherwise the lowering is as follows: + // Src AST: + // x < y < z < a + // Lowering: + // { + // tmp1 = y + // x < tmp1 && { + // tmp2 = z + // tmp1 < tmp2 && { + // tmp2 < a + // } + // } + // } + def convert(compare: ast.Compare): NewNode = + assert(compare.ops.size == compare.comparators.size) + var lhsNode = convert(compare.left) + + val topLevelExprNodes = + lowerComparatorChain(lhsNode, compare.ops, compare.comparators, lineAndColOf(compare)) + if topLevelExprNodes.size > 1 then + createBlock(topLevelExprNodes, lineAndColOf(compare)) + else + topLevelExprNodes.head + + private def compopToOpCodeAndFullName(compareOp: ast.icompop): () => (String, String) = () => + compareOp match + case ast.Eq => ("==", Operators.equals) + case ast.NotEq => ("!=", Operators.notEquals) + case ast.Lt => ("<", Operators.lessThan) + case ast.LtE => ("<=", Operators.lessEqualsThan) + case ast.Gt => (">", Operators.greaterThan) + case ast.GtE => (">=", Operators.greaterEqualsThan) + case ast.Is => ("is", ".is") + case ast.IsNot => ("is not", ".isNot") + case ast.In => ("in", ".in") + case ast.NotIn => ("not in", ".notIn") + + def lowerComparatorChain( + lhsNode: nodes.NewNode, + compOperators: Iterable[ast.icompop], + comparators: Iterable[ast.iexpr], + lineAndColumn: LineAndColumn + ): Iterable[nodes.NewNode] = + val rhsNode = convert(comparators.head) + + if compOperators.size == 1 then + val compareNode = + createBinaryOperatorCall( + lhsNode, + compopToOpCodeAndFullName(compOperators.head), + rhsNode, + lineAndColumn + ) + Iterable.single(compareNode) + else + val tmpVariableName = getUnusedName() + val assignmentNode = + createAssignmentToIdentifier(tmpVariableName, rhsNode, lineAndColumn) + + val tmpIdentifierCompare1 = createIdentifierNode(tmpVariableName, Load, lineAndColumn) + val compareNode = createBinaryOperatorCall( + lhsNode, + compopToOpCodeAndFullName(compOperators.head), + tmpIdentifierCompare1, + lineAndColumn + ) + + val tmpIdentifierCompare2 = createIdentifierNode(tmpVariableName, Load, lineAndColumn) + val childNodes = lowerComparatorChain( + tmpIdentifierCompare2, + compOperators.tail, + comparators.tail, + lineAndColumn + ) + + val blockNode = createBlock(childNodes, lineAndColumn) + + Iterable( + assignmentNode, + createBinaryOperatorCall( + compareNode, + andOpCodeAndFullName(), + blockNode, + lineAndColumn + ) + ) + end if + end lowerComparatorChain + + private def andOpCodeAndFullName(): () => (String, String) = () => + ("and", Operators.logicalAnd) + + /** TODO For now this function compromises on the correctness of the lowering in order to get some + * data flow tracking going. + * 1. For constructs like x.func() we assume x to be the instance which is passed into func. + * This is not true since the instance method object gets the instance already + * bound/captured during function access. This becomes relevant for constructs like: x.func + * \= y.func <- y.func is class method object x.func() In this case the instance passed into + * func is y and not x. We cannot represent this in th CPG and thus stick to the assumption + * that the part before the "." and the bound/captured instance will be the same. For + * reference see: + * https://docs.python.org/3/reference/datamodel.html#the-standard-type-hierarchy search for + * "Instance methods" + */ + def convert(call: ast.Call): nodes.NewNode = + val argumentNodes = call.args.map(convert).toSeq + val keywordArgNodes = call.keywords.flatMap { keyword => + if keyword.arg.isDefined then + Some((keyword.arg.get, convert(keyword.value))) else - "(" + codeOf(tupleElementNodes.head) + ",)" - - val callNode = - nodeBuilder.callNode( - code, - ".tupleLiteral", - DispatchTypes.STATIC_DISPATCH, - lineAndColOf(tuple) - ) + // keyword.arg == None. This is the case for func(**dict) style arguments. + // TODO implement handling for this case. + None + } + + call.func match + case attribute: ast.Attribute => + createXDotYCall( + () => convert(attribute.value), + attribute.attr, + xMayHaveSideEffects = !attribute.value.isInstanceOf[ast.Name], + lineAndColOf(call), + argumentNodes, + keywordArgNodes + ) + case _ => + val receiverNode = convert(call.func) + val name = call.func match + case ast.Name(id, _) => id + case _ => "" + createCall(receiverNode, name, lineAndColOf(call), argumentNodes, keywordArgNodes) + end convert + + def convert(formattedValue: ast.FormattedValue): nodes.NewNode = + val valueNode = convert(formattedValue.value) + + val equalSignStr = if formattedValue.equalSign then "=" else "" + val conversionStr = formattedValue.conversion match + case -1 => "" + case 115 => "!s" + case 114 => "!r" + case 97 => "!a" + + val formatSpecStr = formattedValue.format_spec match + case Some(formatSpec) => ":" + formatSpec + case None => "" + + val code = "{" + codeOf(valueNode) + equalSignStr + conversionStr + formatSpecStr + "}" + + val callNode = nodeBuilder.callNode( + code, + ".formattedValue", + DispatchTypes.STATIC_DISPATCH, + lineAndColOf(formattedValue) + ) - addAstChildrenAsArguments(callNode, 1, tupleElementNodes) + addAstChildrenAsArguments(callNode, 1, valueNode) - callNode - end convert + callNode + end convert - def convert(slice: ast.Slice): NewNode = ??? + def convert(joinedString: ast.JoinedString): nodes.NewNode = + val argumentNodes = joinedString.values.map(convert) - def convert(stringExpList: ast.StringExpList): NewNode = - val stringNodes = stringExpList.elts.map(convert) - val code = stringNodes.map(codeOf).mkString(" ") + val code = joinedString.prefix + joinedString.quote + argumentNodes + .map(codeOf) + .mkString("") + joinedString.quote - val callNode = nodeBuilder.callNode( + val callNode = + nodeBuilder.callNode( code, - ".stringExpressionList", + ".formatString", DispatchTypes.STATIC_DISPATCH, - lineAndColOf(stringExpList) + lineAndColOf(joinedString) ) - addAstChildrenAsArguments(callNode, 1, stringNodes) + addAstChildrenAsArguments(callNode, 1, argumentNodes) - callNode + callNode - // TODO Since there is now real concept of reflecting exception handlers - // semantically in the CPG we just make sure that the variable scoping - // is right and that we convert the exception handler body. - // TODO tests - def convert(exceptHandler: ast.ExceptHandler): NewNode = - contextStack.pushSpecialContext() - val specialTargetLocals = mutable.ArrayBuffer.empty[nodes.NewLocal] - if exceptHandler.name.isDefined then - val localNode = nodeBuilder.localNode(exceptHandler.name.get, None) - specialTargetLocals.append(localNode) - contextStack.addSpecialVariable(localNode) - - val blockNode = createBlock(exceptHandler.body.map(convert), lineAndColOf(exceptHandler)) - addAstChildNodes(blockNode, 1, specialTargetLocals) - - contextStack.pop() - - blockNode - - def convert(parameters: ast.Arguments, startIndex: Int): Iterable[nodes.NewMethodParameterIn] = - val autoIncIndex = new AutoIncIndex(startIndex) + def convert(constant: ast.Constant): nodes.NewNode = + constant.value match + case stringConstant: ast.StringConstant => + nodeBuilder.stringLiteralNode( + stringConstant.prefix + stringConstant.quote + stringConstant + .value + stringConstant.quote, + lineAndColOf(constant) + ) + case stringConstant: ast.JoinedStringConstant => + nodeBuilder.stringLiteralNode(stringConstant.value, lineAndColOf(constant)) + case boolConstant: ast.BoolConstant => + val boolStr = if boolConstant.value then "True" else "False" + nodeBuilder.stringLiteralNode(boolStr, lineAndColOf(constant)) + case intConstant: ast.IntConstant => + nodeBuilder.numberLiteralNode(intConstant.value, lineAndColOf(constant)) + case floatConstant: ast.FloatConstant => + nodeBuilder.numberLiteralNode(floatConstant.value, lineAndColOf(constant)) + case imaginaryConstant: ast.ImaginaryConstant => + nodeBuilder.numberLiteralNode(imaginaryConstant.value + "j", lineAndColOf(constant)) + case ast.NoneConstant => + nodeBuilder.numberLiteralNode("None", lineAndColOf(constant)) + case ast.EllipsisConstant => + nodeBuilder.numberLiteralNode("...", lineAndColOf(constant)) + + /** TODO We currently ignore possible attribute access provider/interception mechanisms like + * __getattr__, __getattribute__ and __get__. + */ + def convert(attribute: ast.Attribute): nodes.NewNode = + val baseNode = convert(attribute.value) + val fieldName = attribute.attr + val lineAndCol = lineAndColOf(attribute) + + val fieldAccess = createFieldAccess(baseNode, fieldName, lineAndCol) + + attribute.value match + case name: ast.Name if name.id == "self" => + createAndRegisterMember(fieldName, lineAndCol) + case _ => + + fieldAccess + + private def createAndRegisterMember(name: String, lineAndCol: LineAndColumn): Unit = + contextStack.findEnclosingTypeDecl() match + case Some(typeDecl: NewTypeDecl) => + if !members.contains(typeDecl) || !members(typeDecl).contains(name) then + val member = nodeBuilder.memberNode(name, lineAndCol) + edgeBuilder.astEdge(member, typeDecl, contextStack.order.getAndInc) + members(typeDecl) = members.getOrElse(typeDecl, List()) ++ List(name) + case _ => + + def convert(subscript: ast.Subscript): NewNode = + createIndexAccess( + convert(subscript.value), + convert(subscript.slice), + lineAndColOf(subscript) + ) + + def convert(starred: ast.Starred): NewNode = + val memoryOperation = memOpMap.get(starred).get + + memoryOperation match + case Load => + val unrollOperand = convert(starred.value) + createStarredUnpackOperatorCall(unrollOperand, lineAndColOf(starred)) + case Store => + unhandled(starred) + case Del => + // This case is not possible since star operator is not allowed in delete statement. + unhandled(starred) + + def convert(name: ast.Name): nodes.NewNode = + val memoryOperation = memOpMap.get(name).get + val identifier = createIdentifierNode(name.id, memoryOperation, lineAndColOf(name)) + if contextStack.isClassContext && memoryOperation == Store then + createAndRegisterMember(identifier.name, lineAndColOf(name)) + identifier + + // TODO test + def convert(list: ast.List): nodes.NewNode = + // Must be a List as part of a Load memory operation because a List literal + // is not permitted as argument to a Del and List as part of a Store does not + // reach here. + assert(memOpMap.get(list).get == Load) + val listElementNodes = list.elts.map(convert) + val code = listElementNodes.map(codeOf).mkString("[", ", ", "]") + + val callNode = + nodeBuilder.callNode( + code, + ".listLiteral", + DispatchTypes.STATIC_DISPATCH, + lineAndColOf(list) + ) - parameters.posonlyargs.map(convertPosOnlyArg(_, autoIncIndex)) ++ - parameters.args.map(convertNormalArg(_, autoIncIndex)) ++ - parameters.vararg.map(convertVarArg(_, autoIncIndex)) ++ - parameters.kwonlyargs.map(convertKeywordOnlyArg) ++ - parameters.kw_arg.map(convertKwArg) + addAstChildrenAsArguments(callNode, 1, listElementNodes) - // TODO for now the different arg convert functions are all the same but - // will all be slightly different in the future when we can represent the - // different types in the cpg. - def convertPosOnlyArg(arg: ast.Arg, index: AutoIncIndex): nodes.NewMethodParameterIn = - nodeBuilder.methodParameterNode( - arg.arg, - isVariadic = false, - lineAndColOf(arg), - Option(index.getAndInc), - arg.annotation - ) + callNode - def convertNormalArg(arg: ast.Arg, index: AutoIncIndex): nodes.NewMethodParameterIn = - nodeBuilder.methodParameterNode( - arg.arg, - isVariadic = false, - lineAndColOf(arg), - Option(index.getAndInc), - arg.annotation - ) + // TODO test + def convert(tuple: ast.Tuple): NewNode = + // Must be a tuple as part of a Load or Del memory operation because Tuples in + // store contexts are not supposed to reach here. They need to be lowered by + // createValueToTargetsDecomposition. + assert(memOpMap.get(tuple).get == Load || memOpMap.get(tuple).get == Del) + val tupleElementNodes = tuple.elts.map(convert) + val code = if tupleElementNodes.size != 1 then + tupleElementNodes.map(codeOf).mkString("(", ", ", ")") + else + "(" + codeOf(tupleElementNodes.head) + ",)" - def convertVarArg(arg: ast.Arg, index: AutoIncIndex): nodes.NewMethodParameterIn = - nodeBuilder.methodParameterNode( - arg.arg, - isVariadic = true, - lineAndColOf(arg), - Option(index.getAndInc) + val callNode = + nodeBuilder.callNode( + code, + ".tupleLiteral", + DispatchTypes.STATIC_DISPATCH, + lineAndColOf(tuple) ) - def convertKeywordOnlyArg(arg: ast.Arg): nodes.NewMethodParameterIn = - nodeBuilder.methodParameterNode(arg.arg, isVariadic = false, lineAndColOf(arg)) + addAstChildrenAsArguments(callNode, 1, tupleElementNodes) - def convertKwArg(arg: ast.Arg): nodes.NewMethodParameterIn = - nodeBuilder.methodParameterNode(arg.arg, isVariadic = false, lineAndColOf(arg)) + callNode + end convert - def convert(keyword: ast.Keyword): NewNode = ??? + def convert(slice: ast.Slice): NewNode = ??? - def convert(alias: ast.Alias): NewNode = ??? + def convert(stringExpList: ast.StringExpList): NewNode = + val stringNodes = stringExpList.elts.map(convert) + val code = stringNodes.map(codeOf).mkString(" ") - def convert(typeIgnore: ast.TypeIgnore): NewNode = ??? + val callNode = nodeBuilder.callNode( + code, + ".stringExpressionList", + DispatchTypes.STATIC_DISPATCH, + lineAndColOf(stringExpList) + ) - private def calculateFullNameFromContext(name: String): String = - val contextQualName = contextStack.qualName - if contextQualName != "" then - relFileName + ":" + contextQualName + "." + name - else - relFileName + ":" + name + addAstChildrenAsArguments(callNode, 1, stringNodes) + + callNode + + // TODO Since there is now real concept of reflecting exception handlers + // semantically in the CPG we just make sure that the variable scoping + // is right and that we convert the exception handler body. + // TODO tests + def convert(exceptHandler: ast.ExceptHandler): NewNode = + contextStack.pushSpecialContext() + val specialTargetLocals = mutable.ArrayBuffer.empty[nodes.NewLocal] + if exceptHandler.name.isDefined then + val localNode = nodeBuilder.localNode(exceptHandler.name.get, None) + specialTargetLocals.append(localNode) + contextStack.addSpecialVariable(localNode) + + val blockNode = createBlock(exceptHandler.body.map(convert), lineAndColOf(exceptHandler)) + addAstChildNodes(blockNode, 1, specialTargetLocals) + + contextStack.pop() + + blockNode + + def convert(parameters: ast.Arguments, startIndex: Int): Iterable[nodes.NewMethodParameterIn] = + val autoIncIndex = new AutoIncIndex(startIndex) + + parameters.posonlyargs.map(convertPosOnlyArg(_, autoIncIndex)) ++ + parameters.args.map(convertNormalArg(_, autoIncIndex)) ++ + parameters.vararg.map(convertVarArg(_, autoIncIndex)) ++ + parameters.kwonlyargs.map(convertKeywordOnlyArg) ++ + parameters.kw_arg.map(convertKwArg) + + // TODO for now the different arg convert functions are all the same but + // will all be slightly different in the future when we can represent the + // different types in the cpg. + def convertPosOnlyArg(arg: ast.Arg, index: AutoIncIndex): nodes.NewMethodParameterIn = + nodeBuilder.methodParameterNode( + arg.arg, + isVariadic = false, + lineAndColOf(arg), + Option(index.getAndInc), + arg.annotation + ) + + def convertNormalArg(arg: ast.Arg, index: AutoIncIndex): nodes.NewMethodParameterIn = + nodeBuilder.methodParameterNode( + arg.arg, + isVariadic = false, + lineAndColOf(arg), + Option(index.getAndInc), + arg.annotation + ) + + def convertVarArg(arg: ast.Arg, index: AutoIncIndex): nodes.NewMethodParameterIn = + nodeBuilder.methodParameterNode( + arg.arg, + isVariadic = true, + lineAndColOf(arg), + Option(index.getAndInc) + ) + + def convertKeywordOnlyArg(arg: ast.Arg): nodes.NewMethodParameterIn = + nodeBuilder.methodParameterNode(arg.arg, isVariadic = false, lineAndColOf(arg)) + + def convertKwArg(arg: ast.Arg): nodes.NewMethodParameterIn = + nodeBuilder.methodParameterNode(arg.arg, isVariadic = false, lineAndColOf(arg)) + + def convert(keyword: ast.Keyword): NewNode = ??? + + def convert(alias: ast.Alias): NewNode = ??? + + def convert(typeIgnore: ast.TypeIgnore): NewNode = ??? + + private def calculateFullNameFromContext(name: String): String = + val contextQualName = contextStack.qualName + if contextQualName != "" then + relFileName + ":" + contextQualName + "." + name + else + relFileName + ":" + name end PythonAstVisitor object PythonAstVisitor: - val builtinPrefix = "__builtin." - val typingPrefix = "typing." - val metaClassSuffix = "" - - // This list contains all functions from https://docs.python.org/3/library/functions.html#built-in-funcs - // for python version 3.9.5. - // There is a corresponding list in policies which needs to be updated if this one is updated and vice versa. - val builtinFunctionsV3: Iterable[String] = Iterable( - "abs", - "aiter", - "all", - "anext", - "any", - "ascii", - "bin", - "breakpoint", - "callable", - "chr", - "classmethod", - "compile", - "delattr", - "dir", - "divmod", - "enumerate", - "eval", - "exec", - "filter", - "format", - "getattr", - "globals", - "hasattr", - "hash", - "help", - "hex", - "id", - "input", - "isinstance", - "issubclass", - "iter", - "len", - "locals", - "map", - "max", - "memoryview", - "min", - "next", - "oct", - "open", - "ord", - "pow", - "print", - "repr", - "reversed", - "round", - "setattr", - "sorted", - "staticmethod", - "sum", - "super", - "vars", - "zip", - "__import__" - ) - // This list contains all classes from https://docs.python.org/3/library/functions.html#built-in-funcs - // for python version 3.9.5. - val builtinClassesV3: Iterable[String] = Iterable( - "bool", - "bytearray", - "bytes", - "complex", - "dict", - "float", - "frozenset", - "int", - "list", - "memoryview", - "object", - "property", - "range", - "set", - "slice", - "str", - "tuple", - "type" - ) - // This list contains all functions from https://docs.python.org/2.7/library/functions.html - val builtinFunctionsV2: Iterable[String] = Iterable( - "abs", - "all", - "any", - "bin", - "callable", - "chr", - "classmethod", - "cmp", - "compile", - "delattr", - "dir", - "divmod", - "enumerate", - "eval", - // This one is special because it is not from the above mentioned list. - // This is because exec is a statement type in V2 but our parser provides - // it to us as a normal call so that we can model it as builtin. - "exec", - "execfile", - "filter", - "format", - "getattr", - "globals", - "hasattr", - "hash", - "help", - "hex", - "id", - "input", - "isinstance", - "issubclass", - "iter", - "len", - "locals", - "map", - "max", - "min", - "next", - "oct", - "open", - "ord", - "pow", - "print", - "range", - "raw_input", - "reduce", - "reload", - "repr", - "reversed", - "round", - "setattr", - "sorted", - "staticmethod", - "sum", - "super", - "unichr", - "vars", - "zip", - "__import__" - ) - // This list contains all classes from https://docs.python.org/2.7/library/functions.html - val builtinClassesV2: Iterable[String] = Iterable( - "bool", - "bytearray", - "complex", - "dict", - "file", - "float", - "frozenset", - "int", - "list", - "long", - "memoryview", - "object", - "property", - "set", - "slice", - "str", - "tuple", - "type", - "unicode", - "xrange" - ) - - lazy val allBuiltinClasses: Set[String] = (builtinClassesV2 ++ builtinClassesV3).toSet - - lazy val typingClassesV3: Set[String] = Set( - "Annotated", - "Any", - "Callable", - "ClassVar", - "Final", - "ForwardRef", - "Generic", - "Literal", - "Optional", - "Protocol", - "Tuple", - "Type", - "TypeVar", - "Union", - "AbstractSet", - "ByteString", - "Container", - "ContextManager", - "Hashable", - "ItemsView", - "Iterable", - "Iterator", - "KeysView", - "Mapping", - "MappingView", - "MutableMapping", - "MutableSequence", - "MutableSet", - "Sequence", - "Sized", - "ValuesView", - "Awaitable", - "AsyncIterator", - "AsyncIterable", - "Coroutine", - "Collection", - "AsyncGenerator", - "AsyncContextManager", - "Reversible", - "SupportsAbs", - "SupportsBytes", - "SupportsComplex", - "SupportsFloat", - "SupportsIndex", - "SupportsInt", - "SupportsRound", - "ChainMap", - "Counter", - "Deque", - "Dict", - "DefaultDict", - "List", - "OrderedDict", - "Set", - "FrozenSet", - "NamedTuple", - "TypedDict", - "Generator", - "BinaryIO", - "IO", - "Match", - "Pattern", - "TextIO", - "AnyStr", - "cast", - "final", - "get_args", - "get_origin", - "get_type_hints", - "NewType", - "no_type_check", - "no_type_check_decorator", - "NoReturn", - "overload", - "runtime_checkable", - "Text", - "TYPE_CHECKING" - ) + val builtinPrefix = "__builtin." + val typingPrefix = "typing." + val metaClassSuffix = "" + + // This list contains all functions from https://docs.python.org/3/library/functions.html#built-in-funcs + // for python version 3.9.5. + // There is a corresponding list in policies which needs to be updated if this one is updated and vice versa. + val builtinFunctionsV3: Iterable[String] = Iterable( + "abs", + "aiter", + "all", + "anext", + "any", + "ascii", + "bin", + "breakpoint", + "callable", + "chr", + "classmethod", + "compile", + "delattr", + "dir", + "divmod", + "enumerate", + "eval", + "exec", + "filter", + "format", + "getattr", + "globals", + "hasattr", + "hash", + "help", + "hex", + "id", + "input", + "isinstance", + "issubclass", + "iter", + "len", + "locals", + "map", + "max", + "memoryview", + "min", + "next", + "oct", + "open", + "ord", + "pow", + "print", + "repr", + "reversed", + "round", + "setattr", + "sorted", + "staticmethod", + "sum", + "super", + "vars", + "zip", + "__import__" + ) + // This list contains all classes from https://docs.python.org/3/library/functions.html#built-in-funcs + // for python version 3.9.5. + val builtinClassesV3: Iterable[String] = Iterable( + "bool", + "bytearray", + "bytes", + "complex", + "dict", + "float", + "frozenset", + "int", + "list", + "memoryview", + "object", + "property", + "range", + "set", + "slice", + "str", + "tuple", + "type" + ) + // This list contains all functions from https://docs.python.org/2.7/library/functions.html + val builtinFunctionsV2: Iterable[String] = Iterable( + "abs", + "all", + "any", + "bin", + "callable", + "chr", + "classmethod", + "cmp", + "compile", + "delattr", + "dir", + "divmod", + "enumerate", + "eval", + // This one is special because it is not from the above mentioned list. + // This is because exec is a statement type in V2 but our parser provides + // it to us as a normal call so that we can model it as builtin. + "exec", + "execfile", + "filter", + "format", + "getattr", + "globals", + "hasattr", + "hash", + "help", + "hex", + "id", + "input", + "isinstance", + "issubclass", + "iter", + "len", + "locals", + "map", + "max", + "min", + "next", + "oct", + "open", + "ord", + "pow", + "print", + "range", + "raw_input", + "reduce", + "reload", + "repr", + "reversed", + "round", + "setattr", + "sorted", + "staticmethod", + "sum", + "super", + "unichr", + "vars", + "zip", + "__import__" + ) + // This list contains all classes from https://docs.python.org/2.7/library/functions.html + val builtinClassesV2: Iterable[String] = Iterable( + "bool", + "bytearray", + "complex", + "dict", + "file", + "float", + "frozenset", + "int", + "list", + "long", + "memoryview", + "object", + "property", + "set", + "slice", + "str", + "tuple", + "type", + "unicode", + "xrange" + ) + + lazy val allBuiltinClasses: Set[String] = (builtinClassesV2 ++ builtinClassesV3).toSet + + lazy val typingClassesV3: Set[String] = Set( + "Annotated", + "Any", + "Callable", + "ClassVar", + "Final", + "ForwardRef", + "Generic", + "Literal", + "Optional", + "Protocol", + "Tuple", + "Type", + "TypeVar", + "Union", + "AbstractSet", + "ByteString", + "Container", + "ContextManager", + "Hashable", + "ItemsView", + "Iterable", + "Iterator", + "KeysView", + "Mapping", + "MappingView", + "MutableMapping", + "MutableSequence", + "MutableSet", + "Sequence", + "Sized", + "ValuesView", + "Awaitable", + "AsyncIterator", + "AsyncIterable", + "Coroutine", + "Collection", + "AsyncGenerator", + "AsyncContextManager", + "Reversible", + "SupportsAbs", + "SupportsBytes", + "SupportsComplex", + "SupportsFloat", + "SupportsIndex", + "SupportsInt", + "SupportsRound", + "ChainMap", + "Counter", + "Deque", + "Dict", + "DefaultDict", + "List", + "OrderedDict", + "Set", + "FrozenSet", + "NamedTuple", + "TypedDict", + "Generator", + "BinaryIO", + "IO", + "Match", + "Pattern", + "TextIO", + "AnyStr", + "cast", + "final", + "get_args", + "get_origin", + "get_type_hints", + "NewType", + "no_type_check", + "no_type_check_decorator", + "NoReturn", + "overload", + "runtime_checkable", + "Text", + "TYPE_CHECKING" + ) end PythonAstVisitor diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonAstVisitorHelpers.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonAstVisitorHelpers.scala index 4d4ab3f6..209ce3e8 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonAstVisitorHelpers.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonAstVisitorHelpers.scala @@ -9,665 +9,665 @@ import scala.collection.immutable.{::, Nil} import scala.collection.mutable trait PythonAstVisitorHelpers: - this: PythonAstVisitor => - - protected def codeOf(node: NewNode): String = - node.asInstanceOf[AstNodeNew].code - - protected def lineAndColOf(node: ast.iattributes): LineAndColumn = - // node.end_col_offset - 1 because the end column offset of the parser points - // behind the last symbol. - LineAndColumn(node.lineno, node.col_offset, node.end_lineno, node.end_col_offset - 1) - - private var tmpCounter = 0 - - protected def getUnusedName(prefix: String = null): String = - // TODO check that result name does not collide with existing variables. - val result = if prefix != null then - s"${prefix}_tmp$tmpCounter" - else - s"tmp$tmpCounter" - tmpCounter += 1 - result - - protected def createTry( - body: Iterable[NewNode], - handlers: Iterable[NewNode], - finalBlock: Iterable[NewNode], - orElseBlock: Iterable[NewNode], - lineAndColumn: LineAndColumn - ): NewNode = - val controlStructureNode = - nodeBuilder.controlStructureNode("try: ...", ControlStructureTypes.TRY, lineAndColumn) - - val bodyBlockNode = createBlock(body, lineAndColumn) - val handlersBlockNode = createBlock(handlers, lineAndColumn) - val finalBlockNode = createBlock(finalBlock, lineAndColumn) - val orElseBlockNode = createBlock(orElseBlock, lineAndColumn) - - addAstChildNodes( - controlStructureNode, - 1, - bodyBlockNode, - handlersBlockNode, - finalBlockNode, - orElseBlockNode - ) - - controlStructureNode - end createTry - - protected def createTransformedImport( - from: String, - names: Iterable[ast.Alias], - lineAndCol: LineAndColumn - ): NewNode = - val importAssignNodes = - names.map { alias => - val importedAsIdentifierName = alias.asName.getOrElse(alias.name) - val importAssignLhsIdentifierNode = - createIdentifierNode(importedAsIdentifierName, Store, lineAndCol) - - val arguments = Seq( - nodeBuilder.stringLiteralNode(from, lineAndCol), - nodeBuilder.stringLiteralNode(alias.name, lineAndCol) - ) ++ (alias.asName match - case Some(aliasName) => - Seq(nodeBuilder.stringLiteralNode(aliasName, lineAndCol)) - case None => Seq() - ) - - val importCallNode = - createCall( - createIdentifierNode("import", Load, lineAndCol), - "import", - lineAndCol, - arguments, - Nil - ) - - val assignNode = - createAssignment(importAssignLhsIdentifierNode, importCallNode, lineAndCol) - assignNode - } - - if importAssignNodes.size > 1 then - createBlock(importAssignNodes, lineAndCol) - else - // Empty importAssignNodes cannot happen. - importAssignNodes.head - end createTransformedImport - - // Used for assign statements, for loop target assignment and - // for comprehension target assignment. - // TODO handle Starred target - protected def createValueToTargetsDecomposition( - targets: Iterable[ast.iexpr], - valueNode: NewNode, - lineAndColumn: LineAndColumn - ): Iterable[NewNode] = - if - targets.size == 1 && - !targets.head.isInstanceOf[ast.Tuple] && - !targets.head.isInstanceOf[ast.List] - then - // No lowering or wrapping in a block is required if we have a single target and - // no decomposition. - val targetNode = convert(targets.head) - - Iterable.single(createAssignment(targetNode, valueNode, lineAndColumn)) - else - // Lowering of x, (y,z) = a = b = c: - // Note: No surrounding block is created. This is the duty of the caller. - // tmp = c - // x = tmp[0] - // y = tmp[1][0] - // z = tmp[1][1] - // a = c - // b = c - // Lowering of for x, (y, z) in c: - // tmp = c - // x = tmp[0] - // y = tmp[1][0] - // z = tmp[1][1] - val tmpVariableName = getUnusedName() - - val tmpVariableAssignNode = - createAssignmentToIdentifier(tmpVariableName, valueNode, lineAndColumn) - - val loweredAssignNodes = mutable.ArrayBuffer.empty[NewNode] - loweredAssignNodes.append(tmpVariableAssignNode) - - targets.foreach { target => - val targetWithAccessChains = getTargetsWithAccessChains(target) - targetWithAccessChains.foreach { case (trgt, accessChain) => - val targetNode = convert(trgt) - val tmpIdentifierNode = - createIdentifierNode(tmpVariableName, Load, lineAndColumn) - val indexTmpIdentifierNode = - createIndexAccessChain(tmpIdentifierNode, accessChain, lineAndColumn) - - val targetAssignNode = - createAssignment(targetNode, indexTmpIdentifierNode, lineAndColumn) - loweredAssignNodes.append(targetAssignNode) - } - } - loweredAssignNodes - - protected def getTargetsWithAccessChains(target: ast.iexpr): Iterable[(ast.iexpr, List[Int])] = - val result = mutable.ArrayBuffer.empty[(ast.iexpr, List[Int])] - getTargetsInternal(target, Nil) - - def getTargetsInternal(target: ast.iexpr, indexChain: List[Int]): Unit = - target match - case tuple: ast.Tuple => - var index = 0 - tuple.elts.foreach { element => - getTargetsInternal(element, index :: indexChain) - index += 1 - } - case list: ast.List => - var index = 0 - list.elts.foreach { element => - getTargetsInternal(element, index :: indexChain) - index += 1 - } - case _ => - result.append((target, indexChain)) - - result - end getTargetsWithAccessChains - - protected def createComprehensionLowering( - tmpVariableName: String, - containerInitAssignNode: NewNode, - innerMostLoopNode: NewNode, - comprehensions: Iterable[ast.Comprehension], - lineAndColumn: LineAndColumn - ): NewNode = - val specialTargetLocals = mutable.ArrayBuffer.empty[NewLocal] - - // Innermost generator is transformed first and becomes the body of the - // generator one layer up. The body of the innermost generator is the - // list comprehensions element expression wrapped in an tmp.append() call. - val nestedLoopBlockNode = - comprehensions.foldRight(innerMostLoopNode) { case (comprehension, loopBodyNode) => - extractComprehensionSpecialVariableNames(comprehension.target).foreach { name => - // For the target names we need to create special scoped variables. - val localNode = nodeBuilder.localNode(name.id, None) - specialTargetLocals.append(localNode) - contextStack.addSpecialVariable(localNode) - } - createForLowering( - comprehension.target, - comprehension.iter, - comprehension.ifs, - Iterable.single(loopBodyNode), - Iterable.empty, - comprehension.is_async, - lineAndColumn - ) - } - - val returnIdentifierNode = createIdentifierNode(tmpVariableName, Load, lineAndColumn) - - val blockNode = - createBlock( - containerInitAssignNode :: nestedLoopBlockNode :: returnIdentifierNode :: Nil, - lineAndColumn - ) - - addAstChildNodes(blockNode, 1, specialTargetLocals) - - blockNode - end createComprehensionLowering - - // Extracts plain names, starred names and name or starred name elements from tuples and lists. - private def extractComprehensionSpecialVariableNames(target: ast.iexpr): Iterable[ast.Name] = - target match - case name: ast.Name => - name :: Nil - case starred: ast.Starred => - extractComprehensionSpecialVariableNames(starred.value) - case tuple: ast.Tuple => - tuple.elts.flatMap(extractComprehensionSpecialVariableNames) - case list: ast.List => - list.elts.flatMap(extractComprehensionSpecialVariableNames) - case _ => + this: PythonAstVisitor => + + protected def codeOf(node: NewNode): String = + node.asInstanceOf[AstNodeNew].code + + protected def lineAndColOf(node: ast.iattributes): LineAndColumn = + // node.end_col_offset - 1 because the end column offset of the parser points + // behind the last symbol. + LineAndColumn(node.lineno, node.col_offset, node.end_lineno, node.end_col_offset - 1) + + private var tmpCounter = 0 + + protected def getUnusedName(prefix: String = null): String = + // TODO check that result name does not collide with existing variables. + val result = if prefix != null then + s"${prefix}_tmp$tmpCounter" + else + s"tmp$tmpCounter" + tmpCounter += 1 + result + + protected def createTry( + body: Iterable[NewNode], + handlers: Iterable[NewNode], + finalBlock: Iterable[NewNode], + orElseBlock: Iterable[NewNode], + lineAndColumn: LineAndColumn + ): NewNode = + val controlStructureNode = + nodeBuilder.controlStructureNode("try: ...", ControlStructureTypes.TRY, lineAndColumn) + + val bodyBlockNode = createBlock(body, lineAndColumn) + val handlersBlockNode = createBlock(handlers, lineAndColumn) + val finalBlockNode = createBlock(finalBlock, lineAndColumn) + val orElseBlockNode = createBlock(orElseBlock, lineAndColumn) + + addAstChildNodes( + controlStructureNode, + 1, + bodyBlockNode, + handlersBlockNode, + finalBlockNode, + orElseBlockNode + ) + + controlStructureNode + end createTry + + protected def createTransformedImport( + from: String, + names: Iterable[ast.Alias], + lineAndCol: LineAndColumn + ): NewNode = + val importAssignNodes = + names.map { alias => + val importedAsIdentifierName = alias.asName.getOrElse(alias.name) + val importAssignLhsIdentifierNode = + createIdentifierNode(importedAsIdentifierName, Store, lineAndCol) + + val arguments = Seq( + nodeBuilder.stringLiteralNode(from, lineAndCol), + nodeBuilder.stringLiteralNode(alias.name, lineAndCol) + ) ++ (alias.asName match + case Some(aliasName) => + Seq(nodeBuilder.stringLiteralNode(aliasName, lineAndCol)) + case None => Seq() + ) + + val importCallNode = + createCall( + createIdentifierNode("import", Load, lineAndCol), + "import", + lineAndCol, + arguments, Nil + ) - protected def createBlock( - blockElements: Iterable[NewNode], - lineAndColumn: LineAndColumn - ): NewNode = - val blockCode = blockElements.map(codeOf).mkString("\n") - val blockNode = nodeBuilder.blockNode(blockCode, lineAndColumn) - - val orderIndex = new AutoIncIndex(1) - addAstChildNodes(blockNode, orderIndex, blockElements) - - blockNode - - protected def createCall( - receiverNode: NewNode, - name: String, - lineAndColumn: LineAndColumn, - argumentNodes: Iterable[NewNode], - keywordArguments: Iterable[(String, NewNode)] - ): NewCall = - val code = codeOf(receiverNode) + - "(" + - argumentNodes.map(codeOf).mkString(", ") + - (if argumentNodes.nonEmpty && keywordArguments.nonEmpty then ", " else "") + - keywordArguments - .map { case (keyword: String, argNode) => keyword + " = " + codeOf(argNode) } - .mkString(", ") + - ")" - val callNode = - nodeBuilder.callNode(code, name, DispatchTypes.DYNAMIC_DISPATCH, lineAndColumn) - - edgeBuilder.astEdge(receiverNode, callNode, 0) - edgeBuilder.receiverEdge(receiverNode, callNode) - - var index = 1 - argumentNodes.foreach { argumentNode => - edgeBuilder.astEdge(argumentNode, callNode, order = index) - edgeBuilder.argumentEdge(argumentNode, callNode, argIndex = index) - index += 1 - } - - keywordArguments.foreach { case (keyword: String, argumentNode) => - edgeBuilder.astEdge(argumentNode, callNode, order = index) - edgeBuilder.argumentEdge(argumentNode, callNode, argName = keyword) - index += 1 + val assignNode = + createAssignment(importAssignLhsIdentifierNode, importCallNode, lineAndCol) + assignNode } - callNode - end createCall - - protected def createInstanceCall( - receiverNode: NewNode, - instanceNode: NewNode, - name: String, - lineAndColumn: LineAndColumn, - argumentNodes: Iterable[NewNode], - keywordArguments: Iterable[(String, NewNode)] - ): NewCall = - val code = codeOf(receiverNode) + - "(" + - argumentNodes.map(codeOf).mkString(", ") + - (if argumentNodes.nonEmpty && keywordArguments.nonEmpty then ", " else "") + - keywordArguments - .map { case (keyword: String, argNode) => keyword + " = " + codeOf(argNode) } - .mkString(", ") + - ")" - val callNode = - nodeBuilder.callNode(code, name, DispatchTypes.DYNAMIC_DISPATCH, lineAndColumn) - - edgeBuilder.astEdge(receiverNode, callNode, 0) - edgeBuilder.receiverEdge(receiverNode, callNode) - edgeBuilder.astEdge(instanceNode, callNode, 1) - edgeBuilder.argumentEdge(instanceNode, callNode, 0) - - var argIndex = 1 - argumentNodes.foreach { argumentNode => - edgeBuilder.astEdge(argumentNode, callNode, argIndex + 1) - edgeBuilder.argumentEdge(argumentNode, callNode, argIndex) - argIndex += 1 + if importAssignNodes.size > 1 then + createBlock(importAssignNodes, lineAndCol) + else + // Empty importAssignNodes cannot happen. + importAssignNodes.head + end createTransformedImport + + // Used for assign statements, for loop target assignment and + // for comprehension target assignment. + // TODO handle Starred target + protected def createValueToTargetsDecomposition( + targets: Iterable[ast.iexpr], + valueNode: NewNode, + lineAndColumn: LineAndColumn + ): Iterable[NewNode] = + if + targets.size == 1 && + !targets.head.isInstanceOf[ast.Tuple] && + !targets.head.isInstanceOf[ast.List] + then + // No lowering or wrapping in a block is required if we have a single target and + // no decomposition. + val targetNode = convert(targets.head) + + Iterable.single(createAssignment(targetNode, valueNode, lineAndColumn)) + else + // Lowering of x, (y,z) = a = b = c: + // Note: No surrounding block is created. This is the duty of the caller. + // tmp = c + // x = tmp[0] + // y = tmp[1][0] + // z = tmp[1][1] + // a = c + // b = c + // Lowering of for x, (y, z) in c: + // tmp = c + // x = tmp[0] + // y = tmp[1][0] + // z = tmp[1][1] + val tmpVariableName = getUnusedName() + + val tmpVariableAssignNode = + createAssignmentToIdentifier(tmpVariableName, valueNode, lineAndColumn) + + val loweredAssignNodes = mutable.ArrayBuffer.empty[NewNode] + loweredAssignNodes.append(tmpVariableAssignNode) + + targets.foreach { target => + val targetWithAccessChains = getTargetsWithAccessChains(target) + targetWithAccessChains.foreach { case (trgt, accessChain) => + val targetNode = convert(trgt) + val tmpIdentifierNode = + createIdentifierNode(tmpVariableName, Load, lineAndColumn) + val indexTmpIdentifierNode = + createIndexAccessChain(tmpIdentifierNode, accessChain, lineAndColumn) + + val targetAssignNode = + createAssignment(targetNode, indexTmpIdentifierNode, lineAndColumn) + loweredAssignNodes.append(targetAssignNode) + } } + loweredAssignNodes - keywordArguments.foreach { case (keyword: String, argumentNode) => - edgeBuilder.astEdge(argumentNode, callNode, order = argIndex + 1) - edgeBuilder.argumentEdge(argumentNode, callNode, argName = keyword) - argIndex += 1 - } + protected def getTargetsWithAccessChains(target: ast.iexpr): Iterable[(ast.iexpr, List[Int])] = + val result = mutable.ArrayBuffer.empty[(ast.iexpr, List[Int])] + getTargetsInternal(target, Nil) - callNode - end createInstanceCall - - // NOTE if xMayHaveSideEffects == false, function x must return a distinct - // tree/node for each invocation!!! - // Otherwise the same tree/node may get placed in different places of the AST - // which is invalid and in this concrete case here triggered setting the - // argumentIndex twice once for x being receiver and once for x being the - // instance. - // If x may have side effects we lower as follows: x.y() => - // { - // tmp = x - // CALL(recv = tmp.y, inst = tmp, args=) - // } - protected def createXDotYCall( - x: () => NewNode, - y: String, - xMayHaveSideEffects: Boolean, - lineAndColumn: LineAndColumn, - argumentNodes: Iterable[NewNode], - keywordArguments: Iterable[(String, NewNode)] - ): NewNode = - if xMayHaveSideEffects then - val tmpVarName = getUnusedName() - val tmpAssignCall = createAssignmentToIdentifier(tmpVarName, x(), lineAndColumn) - val receiverNode = - createFieldAccess( - createIdentifierNode(tmpVarName, Load, lineAndColumn), - y, - lineAndColumn - ) - val instanceNode = createIdentifierNode(tmpVarName, Load, lineAndColumn) - val instanceCallNode = - createInstanceCall( - receiverNode, - instanceNode, - y, - lineAndColumn, - argumentNodes, - keywordArguments - ) - createBlock(tmpAssignCall :: instanceCallNode :: Nil, lineAndColumn) - else - val receiverNode = createFieldAccess(x(), y, lineAndColumn) - createInstanceCall(receiverNode, x(), y, lineAndColumn, argumentNodes, keywordArguments) - - // NOTE: The argument indicies start from 0! - protected def createStaticCall( - name: String, - methodFullName: String, - lineAndColumn: LineAndColumn, - argumentNodes: Iterable[NewNode], - keywordArguments: Iterable[(String, NewNode)] - ): NewNode = - val code = name + - "(" + - argumentNodes.map(codeOf).mkString(", ") + - (if argumentNodes.nonEmpty && keywordArguments.nonEmpty then ", " else "") + - keywordArguments - .map { case (keyword: String, argNode) => keyword + " = " + codeOf(argNode) } - .mkString(", ") + - ")" - val callNode = - nodeBuilder.callNode(code, methodFullName, DispatchTypes.STATIC_DISPATCH, lineAndColumn) - - var argIndex = 0 - argumentNodes.foreach { argumentNode => - edgeBuilder.astEdge(argumentNode, callNode, argIndex) - edgeBuilder.argumentEdge(argumentNode, callNode, argIndex) - argIndex += 1 + def getTargetsInternal(target: ast.iexpr, indexChain: List[Int]): Unit = + target match + case tuple: ast.Tuple => + var index = 0 + tuple.elts.foreach { element => + getTargetsInternal(element, index :: indexChain) + index += 1 + } + case list: ast.List => + var index = 0 + list.elts.foreach { element => + getTargetsInternal(element, index :: indexChain) + index += 1 + } + case _ => + result.append((target, indexChain)) + + result + end getTargetsWithAccessChains + + protected def createComprehensionLowering( + tmpVariableName: String, + containerInitAssignNode: NewNode, + innerMostLoopNode: NewNode, + comprehensions: Iterable[ast.Comprehension], + lineAndColumn: LineAndColumn + ): NewNode = + val specialTargetLocals = mutable.ArrayBuffer.empty[NewLocal] + + // Innermost generator is transformed first and becomes the body of the + // generator one layer up. The body of the innermost generator is the + // list comprehensions element expression wrapped in an tmp.append() call. + val nestedLoopBlockNode = + comprehensions.foldRight(innerMostLoopNode) { case (comprehension, loopBodyNode) => + extractComprehensionSpecialVariableNames(comprehension.target).foreach { name => + // For the target names we need to create special scoped variables. + val localNode = nodeBuilder.localNode(name.id, None) + specialTargetLocals.append(localNode) + contextStack.addSpecialVariable(localNode) + } + createForLowering( + comprehension.target, + comprehension.iter, + comprehension.ifs, + Iterable.single(loopBodyNode), + Iterable.empty, + comprehension.is_async, + lineAndColumn + ) } - keywordArguments.foreach { case (keyword: String, argumentNode) => - edgeBuilder.astEdge(argumentNode, callNode, order = argIndex) - edgeBuilder.argumentEdge(argumentNode, callNode, argName = keyword) - argIndex += 1 - } + val returnIdentifierNode = createIdentifierNode(tmpVariableName, Load, lineAndColumn) - callNode - end createStaticCall - - protected def createNAryOperatorCall( - opCodeAndFullName: () => (String, String), - operands: Iterable[NewNode], - lineAndColumn: LineAndColumn - ): NewNode = - - val (operatorCode, methodFullName) = opCodeAndFullName() - val code = - operands.map(operandNode => codeOf(operandNode)).mkString(" " + operatorCode + " ") - val callNode = - nodeBuilder.callNode(code, methodFullName, DispatchTypes.STATIC_DISPATCH, lineAndColumn) - - addAstChildrenAsArguments(callNode, 1, operands) - - callNode - - protected def createBinaryOperatorCall( - lhsNode: NewNode, - opCodeAndFullName: () => (String, String), - rhsNode: NewNode, - lineAndColumn: LineAndColumn - ): NewCall = - val (opCode, opFullName) = opCodeAndFullName() - - val code = codeOf(lhsNode) + " " + opCode + " " + codeOf(rhsNode) - val callNode = - nodeBuilder.callNode(code, opFullName, DispatchTypes.STATIC_DISPATCH, lineAndColumn) - - addAstChildrenAsArguments(callNode, 1, lhsNode, rhsNode) - callNode - - protected def createLiteralOperatorCall( - codeStart: String, - codeEnd: String, - opFullName: String, - lineAndColumn: LineAndColumn, - operands: NewNode* - ): NewCall = - val code = operands.map(codeOf).mkString(codeStart, ", ", codeEnd) - val callNode = - nodeBuilder.callNode(code, opFullName, DispatchTypes.STATIC_DISPATCH, lineAndColumn) - - addAstChildrenAsArguments(callNode, 1, operands) - - callNode - - protected def createStarredUnpackOperatorCall( - unpackOperand: NewNode, - lineAndColumn: LineAndColumn - ): NewNode = - val code = "*" + codeOf(unpackOperand) - val callNode = nodeBuilder.callNode( - code, - ".starredUnpack", - DispatchTypes.STATIC_DISPATCH, - lineAndColumn - ) - - addAstChildrenAsArguments(callNode, 1, unpackOperand) - callNode - - protected def createAssignment( - lhsNode: NewNode, - rhsNode: NewNode, - lineAndColumn: LineAndColumn - ): NewNode = - val code = codeOf(lhsNode) + " = " + codeOf(rhsNode) - val callNode = nodeBuilder.callNode( - code, - Operators.assignment, - DispatchTypes.STATIC_DISPATCH, - lineAndColumn - ) - - addAstChildrenAsArguments(callNode, 1, lhsNode, rhsNode) - // Do not include imports or function pointers - if !codeOf(rhsNode).startsWith("import(") && codeOf(rhsNode) != s"def ${codeOf(lhsNode)}(...)" - then - contextStack.considerAsGlobalVariable(lhsNode) - - callNode - end createAssignment - - protected def createAssignmentToIdentifier( - identifierName: String, - rhsNode: NewNode, - lineAndColumn: LineAndColumn - ): NewNode = - val identifierNode = createIdentifierNode(identifierName, Store, lineAndColumn) - createAssignment(identifierNode, rhsNode, lineAndColumn) - - protected def createAugAssignment( - lhsNode: NewNode, - operatorCode: String, - rhsNode: NewNode, - operatorFullName: String, - lineAndColumn: LineAndColumn - ): NewNode = - val code = codeOf(lhsNode) + " " + operatorCode + " " + codeOf(rhsNode) - val callNode = nodeBuilder.callNode( - code, - operatorFullName, - DispatchTypes.STATIC_DISPATCH, + val blockNode = + createBlock( + containerInitAssignNode :: nestedLoopBlockNode :: returnIdentifierNode :: Nil, lineAndColumn ) - addAstChildrenAsArguments(callNode, 1, lhsNode, rhsNode) - - callNode - - // Always use this method to create an identifier node instead of - // nodeBuilder.identifierNode() directly to avoid missing to add - // the variable reference. - protected def createIdentifierNode( - name: String, - memOp: MemoryOperation, - lineAndColumn: LineAndColumn - ): NewIdentifier = - val identifierNode = nodeBuilder.identifierNode(name, lineAndColumn) - contextStack.addVariableReference(identifierNode, memOp) - identifierNode - - protected def createIndexAccess( - baseNode: NewNode, - indexNode: NewNode, - lineAndColumn: LineAndColumn - ): NewNode = - val code = codeOf(baseNode) + "[" + codeOf(indexNode) + "]" - val indexAccessNode = - nodeBuilder.callNode( - code, - Operators.indexAccess, - DispatchTypes.STATIC_DISPATCH, + addAstChildNodes(blockNode, 1, specialTargetLocals) + + blockNode + end createComprehensionLowering + + // Extracts plain names, starred names and name or starred name elements from tuples and lists. + private def extractComprehensionSpecialVariableNames(target: ast.iexpr): Iterable[ast.Name] = + target match + case name: ast.Name => + name :: Nil + case starred: ast.Starred => + extractComprehensionSpecialVariableNames(starred.value) + case tuple: ast.Tuple => + tuple.elts.flatMap(extractComprehensionSpecialVariableNames) + case list: ast.List => + list.elts.flatMap(extractComprehensionSpecialVariableNames) + case _ => + Nil + + protected def createBlock( + blockElements: Iterable[NewNode], + lineAndColumn: LineAndColumn + ): NewNode = + val blockCode = blockElements.map(codeOf).mkString("\n") + val blockNode = nodeBuilder.blockNode(blockCode, lineAndColumn) + + val orderIndex = new AutoIncIndex(1) + addAstChildNodes(blockNode, orderIndex, blockElements) + + blockNode + + protected def createCall( + receiverNode: NewNode, + name: String, + lineAndColumn: LineAndColumn, + argumentNodes: Iterable[NewNode], + keywordArguments: Iterable[(String, NewNode)] + ): NewCall = + val code = codeOf(receiverNode) + + "(" + + argumentNodes.map(codeOf).mkString(", ") + + (if argumentNodes.nonEmpty && keywordArguments.nonEmpty then ", " else "") + + keywordArguments + .map { case (keyword: String, argNode) => keyword + " = " + codeOf(argNode) } + .mkString(", ") + + ")" + val callNode = + nodeBuilder.callNode(code, name, DispatchTypes.DYNAMIC_DISPATCH, lineAndColumn) + + edgeBuilder.astEdge(receiverNode, callNode, 0) + edgeBuilder.receiverEdge(receiverNode, callNode) + + var index = 1 + argumentNodes.foreach { argumentNode => + edgeBuilder.astEdge(argumentNode, callNode, order = index) + edgeBuilder.argumentEdge(argumentNode, callNode, argIndex = index) + index += 1 + } + + keywordArguments.foreach { case (keyword: String, argumentNode) => + edgeBuilder.astEdge(argumentNode, callNode, order = index) + edgeBuilder.argumentEdge(argumentNode, callNode, argName = keyword) + index += 1 + } + + callNode + end createCall + + protected def createInstanceCall( + receiverNode: NewNode, + instanceNode: NewNode, + name: String, + lineAndColumn: LineAndColumn, + argumentNodes: Iterable[NewNode], + keywordArguments: Iterable[(String, NewNode)] + ): NewCall = + val code = codeOf(receiverNode) + + "(" + + argumentNodes.map(codeOf).mkString(", ") + + (if argumentNodes.nonEmpty && keywordArguments.nonEmpty then ", " else "") + + keywordArguments + .map { case (keyword: String, argNode) => keyword + " = " + codeOf(argNode) } + .mkString(", ") + + ")" + val callNode = + nodeBuilder.callNode(code, name, DispatchTypes.DYNAMIC_DISPATCH, lineAndColumn) + + edgeBuilder.astEdge(receiverNode, callNode, 0) + edgeBuilder.receiverEdge(receiverNode, callNode) + edgeBuilder.astEdge(instanceNode, callNode, 1) + edgeBuilder.argumentEdge(instanceNode, callNode, 0) + + var argIndex = 1 + argumentNodes.foreach { argumentNode => + edgeBuilder.astEdge(argumentNode, callNode, argIndex + 1) + edgeBuilder.argumentEdge(argumentNode, callNode, argIndex) + argIndex += 1 + } + + keywordArguments.foreach { case (keyword: String, argumentNode) => + edgeBuilder.astEdge(argumentNode, callNode, order = argIndex + 1) + edgeBuilder.argumentEdge(argumentNode, callNode, argName = keyword) + argIndex += 1 + } + + callNode + end createInstanceCall + + // NOTE if xMayHaveSideEffects == false, function x must return a distinct + // tree/node for each invocation!!! + // Otherwise the same tree/node may get placed in different places of the AST + // which is invalid and in this concrete case here triggered setting the + // argumentIndex twice once for x being receiver and once for x being the + // instance. + // If x may have side effects we lower as follows: x.y() => + // { + // tmp = x + // CALL(recv = tmp.y, inst = tmp, args=) + // } + protected def createXDotYCall( + x: () => NewNode, + y: String, + xMayHaveSideEffects: Boolean, + lineAndColumn: LineAndColumn, + argumentNodes: Iterable[NewNode], + keywordArguments: Iterable[(String, NewNode)] + ): NewNode = + if xMayHaveSideEffects then + val tmpVarName = getUnusedName() + val tmpAssignCall = createAssignmentToIdentifier(tmpVarName, x(), lineAndColumn) + val receiverNode = + createFieldAccess( + createIdentifierNode(tmpVarName, Load, lineAndColumn), + y, lineAndColumn ) - - addAstChildrenAsArguments(indexAccessNode, 1, baseNode, indexNode) - - indexAccessNode - - protected def createIndexAccessChain( - rootNode: NewNode, - accessChain: List[Int], - lineAndColumn: LineAndColumn - ): NewNode = - accessChain match - case accessIndex :: tail => - val baseNode = createIndexAccessChain(rootNode, tail, lineAndColumn) - val indexNode = nodeBuilder.numberLiteralNode(accessIndex, lineAndColumn) - - createIndexAccess(baseNode, indexNode, lineAndColumn) - case Nil => - rootNode - - protected def createFieldAccess( - baseNode: NewNode, - fieldName: String, - lineAndColumn: LineAndColumn - ): NewCall = - val fieldIdNode = nodeBuilder.fieldIdentifierNode(fieldName, lineAndColumn) - - val code = codeOf(baseNode) + "." + codeOf(fieldIdNode) - val callNode = nodeBuilder.callNode( + val instanceNode = createIdentifierNode(tmpVarName, Load, lineAndColumn) + val instanceCallNode = + createInstanceCall( + receiverNode, + instanceNode, + y, + lineAndColumn, + argumentNodes, + keywordArguments + ) + createBlock(tmpAssignCall :: instanceCallNode :: Nil, lineAndColumn) + else + val receiverNode = createFieldAccess(x(), y, lineAndColumn) + createInstanceCall(receiverNode, x(), y, lineAndColumn, argumentNodes, keywordArguments) + + // NOTE: The argument indicies start from 0! + protected def createStaticCall( + name: String, + methodFullName: String, + lineAndColumn: LineAndColumn, + argumentNodes: Iterable[NewNode], + keywordArguments: Iterable[(String, NewNode)] + ): NewNode = + val code = name + + "(" + + argumentNodes.map(codeOf).mkString(", ") + + (if argumentNodes.nonEmpty && keywordArguments.nonEmpty then ", " else "") + + keywordArguments + .map { case (keyword: String, argNode) => keyword + " = " + codeOf(argNode) } + .mkString(", ") + + ")" + val callNode = + nodeBuilder.callNode(code, methodFullName, DispatchTypes.STATIC_DISPATCH, lineAndColumn) + + var argIndex = 0 + argumentNodes.foreach { argumentNode => + edgeBuilder.astEdge(argumentNode, callNode, argIndex) + edgeBuilder.argumentEdge(argumentNode, callNode, argIndex) + argIndex += 1 + } + + keywordArguments.foreach { case (keyword: String, argumentNode) => + edgeBuilder.astEdge(argumentNode, callNode, order = argIndex) + edgeBuilder.argumentEdge(argumentNode, callNode, argName = keyword) + argIndex += 1 + } + + callNode + end createStaticCall + + protected def createNAryOperatorCall( + opCodeAndFullName: () => (String, String), + operands: Iterable[NewNode], + lineAndColumn: LineAndColumn + ): NewNode = + + val (operatorCode, methodFullName) = opCodeAndFullName() + val code = + operands.map(operandNode => codeOf(operandNode)).mkString(" " + operatorCode + " ") + val callNode = + nodeBuilder.callNode(code, methodFullName, DispatchTypes.STATIC_DISPATCH, lineAndColumn) + + addAstChildrenAsArguments(callNode, 1, operands) + + callNode + + protected def createBinaryOperatorCall( + lhsNode: NewNode, + opCodeAndFullName: () => (String, String), + rhsNode: NewNode, + lineAndColumn: LineAndColumn + ): NewCall = + val (opCode, opFullName) = opCodeAndFullName() + + val code = codeOf(lhsNode) + " " + opCode + " " + codeOf(rhsNode) + val callNode = + nodeBuilder.callNode(code, opFullName, DispatchTypes.STATIC_DISPATCH, lineAndColumn) + + addAstChildrenAsArguments(callNode, 1, lhsNode, rhsNode) + callNode + + protected def createLiteralOperatorCall( + codeStart: String, + codeEnd: String, + opFullName: String, + lineAndColumn: LineAndColumn, + operands: NewNode* + ): NewCall = + val code = operands.map(codeOf).mkString(codeStart, ", ", codeEnd) + val callNode = + nodeBuilder.callNode(code, opFullName, DispatchTypes.STATIC_DISPATCH, lineAndColumn) + + addAstChildrenAsArguments(callNode, 1, operands) + + callNode + + protected def createStarredUnpackOperatorCall( + unpackOperand: NewNode, + lineAndColumn: LineAndColumn + ): NewNode = + val code = "*" + codeOf(unpackOperand) + val callNode = nodeBuilder.callNode( + code, + ".starredUnpack", + DispatchTypes.STATIC_DISPATCH, + lineAndColumn + ) + + addAstChildrenAsArguments(callNode, 1, unpackOperand) + callNode + + protected def createAssignment( + lhsNode: NewNode, + rhsNode: NewNode, + lineAndColumn: LineAndColumn + ): NewNode = + val code = codeOf(lhsNode) + " = " + codeOf(rhsNode) + val callNode = nodeBuilder.callNode( + code, + Operators.assignment, + DispatchTypes.STATIC_DISPATCH, + lineAndColumn + ) + + addAstChildrenAsArguments(callNode, 1, lhsNode, rhsNode) + // Do not include imports or function pointers + if !codeOf(rhsNode).startsWith("import(") && codeOf(rhsNode) != s"def ${codeOf(lhsNode)}(...)" + then + contextStack.considerAsGlobalVariable(lhsNode) + + callNode + end createAssignment + + protected def createAssignmentToIdentifier( + identifierName: String, + rhsNode: NewNode, + lineAndColumn: LineAndColumn + ): NewNode = + val identifierNode = createIdentifierNode(identifierName, Store, lineAndColumn) + createAssignment(identifierNode, rhsNode, lineAndColumn) + + protected def createAugAssignment( + lhsNode: NewNode, + operatorCode: String, + rhsNode: NewNode, + operatorFullName: String, + lineAndColumn: LineAndColumn + ): NewNode = + val code = codeOf(lhsNode) + " " + operatorCode + " " + codeOf(rhsNode) + val callNode = nodeBuilder.callNode( + code, + operatorFullName, + DispatchTypes.STATIC_DISPATCH, + lineAndColumn + ) + + addAstChildrenAsArguments(callNode, 1, lhsNode, rhsNode) + + callNode + + // Always use this method to create an identifier node instead of + // nodeBuilder.identifierNode() directly to avoid missing to add + // the variable reference. + protected def createIdentifierNode( + name: String, + memOp: MemoryOperation, + lineAndColumn: LineAndColumn + ): NewIdentifier = + val identifierNode = nodeBuilder.identifierNode(name, lineAndColumn) + contextStack.addVariableReference(identifierNode, memOp) + identifierNode + + protected def createIndexAccess( + baseNode: NewNode, + indexNode: NewNode, + lineAndColumn: LineAndColumn + ): NewNode = + val code = codeOf(baseNode) + "[" + codeOf(indexNode) + "]" + val indexAccessNode = + nodeBuilder.callNode( code, - Operators.fieldAccess, + Operators.indexAccess, DispatchTypes.STATIC_DISPATCH, lineAndColumn ) - addAstChildrenAsArguments(callNode, 1, baseNode, fieldIdNode) - callNode - - protected def createTypeRef( - typeName: String, - typeFullName: String, - lineAndColumn: LineAndColumn - ): NewTypeRef = - nodeBuilder.typeRefNode("class " + typeName + "(...)", typeFullName, lineAndColumn) - - protected def createBinding(methodNode: NewMethod, typeDeclNode: NewTypeDecl): NewBinding = - val bindingNode = nodeBuilder.bindingNode() - edgeBuilder.bindsEdge(bindingNode, typeDeclNode) - edgeBuilder.refEdge(methodNode, bindingNode) - - bindingNode - - protected def createReturn( - returnExprOption: Option[NewNode], - codeOption: Option[String], - lineAndColumn: LineAndColumn - ): NewReturn = - val code = codeOption.getOrElse { - returnExprOption match - case Some(returnExpr) => - "return " + codeOf(returnExpr) - case None => - "return" - } - val returnNode = nodeBuilder.returnNode(code, lineAndColumn) - returnExprOption.foreach { returnExpr => - addAstChildrenAsArguments(returnNode, 1, returnExpr) - } - - returnNode - - protected def addAstChildNodes( - parentNode: NewNode, - startIndex: AutoIncIndex, - childNodes: Iterable[NewNode] - ): Unit = - childNodes.foreach { childNode => - val orderIndex = startIndex.getAndInc - edgeBuilder.astEdge(childNode, parentNode, orderIndex) - } - - protected def addAstChildNodes( - parentNode: NewNode, - startIndex: Int, - childNodes: Iterable[NewNode] - ): Unit = - addAstChildNodes(parentNode, new AutoIncIndex(startIndex), childNodes) - - protected def addAstChildNodes( - parentNode: NewNode, - startIndex: AutoIncIndex, - childNodes: NewNode* - ): Unit = - addAstChildNodes(parentNode, startIndex, childNodes) - - protected def addAstChildNodes( - parentNode: NewNode, - startIndex: Int, - childNodes: NewNode* - ): Unit = - addAstChildNodes(parentNode, new AutoIncIndex(startIndex), childNodes) - - protected def addAstChildrenAsArguments( - parentNode: NewNode, - startIndex: AutoIncIndex, - childNodes: Iterable[NewNode] - ): Unit = - childNodes.foreach { childNode => - val orderAndArgIndex = startIndex.getAndInc - edgeBuilder.astEdge(childNode, parentNode, orderAndArgIndex) - edgeBuilder.argumentEdge(childNode, parentNode, orderAndArgIndex) - } - - protected def addAstChildrenAsArguments( - parentNode: NewNode, - startIndex: Int, - childNodes: Iterable[NewNode] - ): Unit = - addAstChildrenAsArguments(parentNode, new AutoIncIndex(startIndex), childNodes) - - protected def addAstChildrenAsArguments( - parentNode: NewNode, - startIndex: AutoIncIndex, - childNodes: NewNode* - ): Unit = - addAstChildrenAsArguments(parentNode, startIndex, childNodes) - - protected def addAstChildrenAsArguments( - parentNode: NewNode, - startIndex: Int, - childNodes: NewNode* - ): Unit = - addAstChildrenAsArguments(parentNode, new AutoIncIndex(startIndex), childNodes) + addAstChildrenAsArguments(indexAccessNode, 1, baseNode, indexNode) + + indexAccessNode + + protected def createIndexAccessChain( + rootNode: NewNode, + accessChain: List[Int], + lineAndColumn: LineAndColumn + ): NewNode = + accessChain match + case accessIndex :: tail => + val baseNode = createIndexAccessChain(rootNode, tail, lineAndColumn) + val indexNode = nodeBuilder.numberLiteralNode(accessIndex, lineAndColumn) + + createIndexAccess(baseNode, indexNode, lineAndColumn) + case Nil => + rootNode + + protected def createFieldAccess( + baseNode: NewNode, + fieldName: String, + lineAndColumn: LineAndColumn + ): NewCall = + val fieldIdNode = nodeBuilder.fieldIdentifierNode(fieldName, lineAndColumn) + + val code = codeOf(baseNode) + "." + codeOf(fieldIdNode) + val callNode = nodeBuilder.callNode( + code, + Operators.fieldAccess, + DispatchTypes.STATIC_DISPATCH, + lineAndColumn + ) + + addAstChildrenAsArguments(callNode, 1, baseNode, fieldIdNode) + callNode + + protected def createTypeRef( + typeName: String, + typeFullName: String, + lineAndColumn: LineAndColumn + ): NewTypeRef = + nodeBuilder.typeRefNode("class " + typeName + "(...)", typeFullName, lineAndColumn) + + protected def createBinding(methodNode: NewMethod, typeDeclNode: NewTypeDecl): NewBinding = + val bindingNode = nodeBuilder.bindingNode() + edgeBuilder.bindsEdge(bindingNode, typeDeclNode) + edgeBuilder.refEdge(methodNode, bindingNode) + + bindingNode + + protected def createReturn( + returnExprOption: Option[NewNode], + codeOption: Option[String], + lineAndColumn: LineAndColumn + ): NewReturn = + val code = codeOption.getOrElse { + returnExprOption match + case Some(returnExpr) => + "return " + codeOf(returnExpr) + case None => + "return" + } + val returnNode = nodeBuilder.returnNode(code, lineAndColumn) + returnExprOption.foreach { returnExpr => + addAstChildrenAsArguments(returnNode, 1, returnExpr) + } + + returnNode + + protected def addAstChildNodes( + parentNode: NewNode, + startIndex: AutoIncIndex, + childNodes: Iterable[NewNode] + ): Unit = + childNodes.foreach { childNode => + val orderIndex = startIndex.getAndInc + edgeBuilder.astEdge(childNode, parentNode, orderIndex) + } + + protected def addAstChildNodes( + parentNode: NewNode, + startIndex: Int, + childNodes: Iterable[NewNode] + ): Unit = + addAstChildNodes(parentNode, new AutoIncIndex(startIndex), childNodes) + + protected def addAstChildNodes( + parentNode: NewNode, + startIndex: AutoIncIndex, + childNodes: NewNode* + ): Unit = + addAstChildNodes(parentNode, startIndex, childNodes) + + protected def addAstChildNodes( + parentNode: NewNode, + startIndex: Int, + childNodes: NewNode* + ): Unit = + addAstChildNodes(parentNode, new AutoIncIndex(startIndex), childNodes) + + protected def addAstChildrenAsArguments( + parentNode: NewNode, + startIndex: AutoIncIndex, + childNodes: Iterable[NewNode] + ): Unit = + childNodes.foreach { childNode => + val orderAndArgIndex = startIndex.getAndInc + edgeBuilder.astEdge(childNode, parentNode, orderAndArgIndex) + edgeBuilder.argumentEdge(childNode, parentNode, orderAndArgIndex) + } + + protected def addAstChildrenAsArguments( + parentNode: NewNode, + startIndex: Int, + childNodes: Iterable[NewNode] + ): Unit = + addAstChildrenAsArguments(parentNode, new AutoIncIndex(startIndex), childNodes) + + protected def addAstChildrenAsArguments( + parentNode: NewNode, + startIndex: AutoIncIndex, + childNodes: NewNode* + ): Unit = + addAstChildrenAsArguments(parentNode, startIndex, childNodes) + + protected def addAstChildrenAsArguments( + parentNode: NewNode, + startIndex: Int, + childNodes: NewNode* + ): Unit = + addAstChildrenAsArguments(parentNode, new AutoIncIndex(startIndex), childNodes) end PythonAstVisitorHelpers diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonInheritanceNamePass.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonInheritanceNamePass.scala index 473c4f70..07c0568d 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonInheritanceNamePass.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonInheritanceNamePass.scala @@ -8,5 +8,5 @@ import io.shiftleft.codepropertygraph.Cpg */ class PythonInheritanceNamePass(cpg: Cpg) extends XInheritanceFullNamePass(cpg): - override val moduleName: String = "" - override val fileExt: String = ".py" + override val moduleName: String = "" + override val fileExt: String = ".py" diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonTypeHintCallLinker.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonTypeHintCallLinker.scala index 10830407..9ee68105 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonTypeHintCallLinker.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonTypeHintCallLinker.scala @@ -8,22 +8,22 @@ import io.shiftleft.semanticcpg.language.* class PythonTypeHintCallLinker(cpg: Cpg) extends XTypeHintCallLinker(cpg): - override def calls: Iterator[Call] = super.calls.nameNot("^(import).*") + override def calls: Iterator[Call] = super.calls.nameNot("^(import).*") - override def calleeNames(c: Call): Seq[String] = super.calleeNames(c).map { - // Python call from a type - case typ if typ.split("\\.").lastOption.exists(_.charAt(0).isUpper) => s"$typ.__init__" - // Python call from a function pointer - case typ => typ - } + override def calleeNames(c: Call): Seq[String] = super.calleeNames(c).map { + // Python call from a type + case typ if typ.split("\\.").lastOption.exists(_.charAt(0).isUpper) => s"$typ.__init__" + // Python call from a function pointer + case typ => typ + } - override def setCallees(call: Call, methodNames: Seq[String], builder: DiffGraphBuilder): Unit = - if methodNames.sizeIs == 1 then - super.setCallees(call, methodNames, builder) - else if methodNames.sizeIs > 1 then - val nonDummyMethodNames = - methodNames.filterNot(x => - isDummyType(x) || x.startsWith(PythonAstVisitor.builtinPrefix + "None") - ) - super.setCallees(call, nonDummyMethodNames, builder) + override def setCallees(call: Call, methodNames: Seq[String], builder: DiffGraphBuilder): Unit = + if methodNames.sizeIs == 1 then + super.setCallees(call, methodNames, builder) + else if methodNames.sizeIs > 1 then + val nonDummyMethodNames = + methodNames.filterNot(x => + isDummyType(x) || x.startsWith(PythonAstVisitor.builtinPrefix + "None") + ) + super.setCallees(call, nonDummyMethodNames, builder) end PythonTypeHintCallLinker diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonTypeRecovery.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonTypeRecovery.scala index 187aff15..465ac74d 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonTypeRecovery.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/PythonTypeRecovery.scala @@ -12,22 +12,22 @@ import overflowdb.BatchedUpdate.DiffGraphBuilder class PythonTypeRecoveryPass(cpg: Cpg, config: XTypeRecoveryConfig = XTypeRecoveryConfig()) extends XTypeRecoveryPass[File](cpg, config): - override protected def generateRecoveryPass(state: XTypeRecoveryState): XTypeRecovery[File] = - new PythonTypeRecovery(cpg, state) + override protected def generateRecoveryPass(state: XTypeRecoveryState): XTypeRecovery[File] = + new PythonTypeRecovery(cpg, state) private class PythonTypeRecovery(cpg: Cpg, state: XTypeRecoveryState) extends XTypeRecovery[File](cpg, state): - override def compilationUnit: Iterator[File] = cpg.file.iterator + override def compilationUnit: Iterator[File] = cpg.file.iterator - override def generateRecoveryForCompilationUnitTask( - unit: File, - builder: DiffGraphBuilder - ): RecoverForXCompilationUnit[File] = - val newConfig = state.config.copy(enabledDummyTypes = - state.isFinalIteration && state.config.enabledDummyTypes - ) - new RecoverForPythonFile(cpg, unit, builder, state.copy(config = newConfig)) + override def generateRecoveryForCompilationUnitTask( + unit: File, + builder: DiffGraphBuilder + ): RecoverForXCompilationUnit[File] = + val newConfig = state.config.copy(enabledDummyTypes = + state.isFinalIteration && state.config.enabledDummyTypes + ) + new RecoverForPythonFile(cpg, unit, builder, state.copy(config = newConfig)) /** Performs type recovery from the root of a compilation unit level */ @@ -38,219 +38,219 @@ private class RecoverForPythonFile( state: XTypeRecoveryState ) extends RecoverForXCompilationUnit[File](cpg, cu, builder, state): - override val symbolTable: SymbolTable[LocalKey] = - new SymbolTable[LocalKey](fromNodeToLocalPythonKey) + override val symbolTable: SymbolTable[LocalKey] = + new SymbolTable[LocalKey](fromNodeToLocalPythonKey) - override def visitImport(i: Import): Unit = - if i.importedAs.isDefined && i.importedEntity.isDefined then - import io.appthreat.x2cpg.passes.frontend.ImportsPass.* + override def visitImport(i: Import): Unit = + if i.importedAs.isDefined && i.importedEntity.isDefined then + import io.appthreat.x2cpg.passes.frontend.ImportsPass.* - val entityName = i.importedAs.get - i.call.tag.flatMap(ResolvedImport.tagToResolvedImport).foreach { - case ResolvedMethod(fullName, alias, receiver, _) => - symbolTable.put(CallAlias(alias, receiver), fullName) - case ResolvedTypeDecl(fullName, _) => - symbolTable.put(LocalVar(entityName), fullName) - case ResolvedMember(basePath, memberName, _) => - val memberTypes = cpg.typeDecl - .fullNameExact(basePath) - .member - .nameExact(memberName) - .flatMap(m => m.typeFullName +: m.dynamicTypeHintFullName) - .filterNot(_ == "ANY") - .toSet - symbolTable.put(LocalVar(entityName), memberTypes) - case UnknownMethod(fullName, alias, receiver, _) => - symbolTable.put(CallAlias(alias, receiver), fullName) - case UnknownTypeDecl(fullName, _) => - symbolTable.put(LocalVar(entityName), fullName) - case UnknownImport(path, _) => - symbolTable.put(CallAlias(entityName), path) - symbolTable.put(LocalVar(entityName), path) - } - - override def visitAssignments(a: OpNodes.Assignment): Set[String] = - a.argumentOut.l match - case List(i: Identifier, c: Call) if c.name.isBlank && c.signature.isBlank => - // This is usually some decorator wrapper - c.argument.isMethodRef.headOption match - case Some(mRef) => visitIdentifierAssignedToMethodRef(i, mRef) - case None => super.visitAssignments(a) - case _ => super.visitAssignments(a) + val entityName = i.importedAs.get + i.call.tag.flatMap(ResolvedImport.tagToResolvedImport).foreach { + case ResolvedMethod(fullName, alias, receiver, _) => + symbolTable.put(CallAlias(alias, receiver), fullName) + case ResolvedTypeDecl(fullName, _) => + symbolTable.put(LocalVar(entityName), fullName) + case ResolvedMember(basePath, memberName, _) => + val memberTypes = cpg.typeDecl + .fullNameExact(basePath) + .member + .nameExact(memberName) + .flatMap(m => m.typeFullName +: m.dynamicTypeHintFullName) + .filterNot(_ == "ANY") + .toSet + symbolTable.put(LocalVar(entityName), memberTypes) + case UnknownMethod(fullName, alias, receiver, _) => + symbolTable.put(CallAlias(alias, receiver), fullName) + case UnknownTypeDecl(fullName, _) => + symbolTable.put(LocalVar(entityName), fullName) + case UnknownImport(path, _) => + symbolTable.put(CallAlias(entityName), path) + symbolTable.put(LocalVar(entityName), path) + } - /** Determines if a function call is a constructor by following the heuristic that Python - * classes are typically camel-case and start with an upper-case character. - */ - override def isConstructor(c: Call): Boolean = - isConstructor(c.name) && c.code.endsWith(")") + override def visitAssignments(a: OpNodes.Assignment): Set[String] = + a.argumentOut.l match + case List(i: Identifier, c: Call) if c.name.isBlank && c.signature.isBlank => + // This is usually some decorator wrapper + c.argument.isMethodRef.headOption match + case Some(mRef) => visitIdentifierAssignedToMethodRef(i, mRef) + case None => super.visitAssignments(a) + case _ => super.visitAssignments(a) - /** If the parent method is module then it can be used as a field. - */ - override def isField(i: Identifier): Boolean = - state.isFieldCache.getOrElseUpdate( - i.id(), - i.method.name.matches("(|__init__)") || super.isField(i) - ) + /** Determines if a function call is a constructor by following the heuristic that Python classes + * are typically camel-case and start with an upper-case character. + */ + override def isConstructor(c: Call): Boolean = + isConstructor(c.name) && c.code.endsWith(")") - override def visitIdentifierAssignedToOperator( - i: Identifier, - c: Call, - operation: String - ): Set[String] = - operation match - case ".listLiteral" => - associateTypes(i, Set(s"${PythonAstVisitor.builtinPrefix}list")) - case ".tupleLiteral" => - associateTypes(i, Set(s"${PythonAstVisitor.builtinPrefix}tuple")) - case ".dictLiteral" => - associateTypes(i, Set(s"${PythonAstVisitor.builtinPrefix}dict")) - case ".setLiteral" => - associateTypes(i, Set(s"${PythonAstVisitor.builtinPrefix}set")) - case Operators.conditional => - associateTypes(i, Set(s"${PythonAstVisitor.builtinPrefix}bool")) - case Operators.indexAccess => - c.argument.argumentIndex(1).isCall.foreach(setCallMethodFullNameFromBase) - visitIdentifierAssignedToIndexAccess(i, c) - case _ => super.visitIdentifierAssignedToOperator(i, c, operation) + /** If the parent method is module then it can be used as a field. + */ + override def isField(i: Identifier): Boolean = + state.isFieldCache.getOrElseUpdate( + i.id(), + i.method.name.matches("(|__init__)") || super.isField(i) + ) - override def visitIdentifierAssignedToConstructor(i: Identifier, c: Call): Set[String] = - val constructorPaths = symbolTable.get(c).map(_.stripSuffix(s"${pathSep}__init__")) - associateTypes(i, constructorPaths) + override def visitIdentifierAssignedToOperator( + i: Identifier, + c: Call, + operation: String + ): Set[String] = + operation match + case ".listLiteral" => + associateTypes(i, Set(s"${PythonAstVisitor.builtinPrefix}list")) + case ".tupleLiteral" => + associateTypes(i, Set(s"${PythonAstVisitor.builtinPrefix}tuple")) + case ".dictLiteral" => + associateTypes(i, Set(s"${PythonAstVisitor.builtinPrefix}dict")) + case ".setLiteral" => + associateTypes(i, Set(s"${PythonAstVisitor.builtinPrefix}set")) + case Operators.conditional => + associateTypes(i, Set(s"${PythonAstVisitor.builtinPrefix}bool")) + case Operators.indexAccess => + c.argument.argumentIndex(1).isCall.foreach(setCallMethodFullNameFromBase) + visitIdentifierAssignedToIndexAccess(i, c) + case _ => super.visitIdentifierAssignedToOperator(i, c, operation) - override def visitIdentifierAssignedToCall(i: Identifier, c: Call): Set[String] = - // Ignore legacy import representation - if c.name.equals("import") then Set.empty - // Stop custom annotation representation from hitting superclass - else if c.name.isBlank then Set.empty - else super.visitIdentifierAssignedToCall(i, c) + override def visitIdentifierAssignedToConstructor(i: Identifier, c: Call): Set[String] = + val constructorPaths = symbolTable.get(c).map(_.stripSuffix(s"${pathSep}__init__")) + associateTypes(i, constructorPaths) - override def visitIdentifierAssignedToFieldLoad(i: Identifier, fa: FieldAccess): Set[String] = - val fieldParents = getFieldParents(fa) - fa.astChildren.l match - case List(base: Identifier, fi: FieldIdentifier) - if base.name.equals("self") && fieldParents.nonEmpty => - val referencedFields = cpg.typeDecl.fullNameExact( - fieldParents.toSeq* - ).member.nameExact(fi.canonicalName) - val globalTypes = - referencedFields.flatMap(m => - m.typeFullName +: m.dynamicTypeHintFullName - ).filterNot(_ == Constants.ANY).toSet - associateTypes(i, globalTypes) - case _ => super.visitIdentifierAssignedToFieldLoad(i, fa) + override def visitIdentifierAssignedToCall(i: Identifier, c: Call): Set[String] = + // Ignore legacy import representation + if c.name.equals("import") then Set.empty + // Stop custom annotation representation from hitting superclass + else if c.name.isBlank then Set.empty + else super.visitIdentifierAssignedToCall(i, c) - override def getFieldParents(fa: FieldAccess): Set[String] = - if fa.method.name == "" then - Set(fa.method.fullName) - else if fa.method.typeDecl.nonEmpty then - val parentTypes = fa.method.typeDecl.fullName.toSet - val baseTypeFullNames = - cpg.typeDecl.fullNameExact(parentTypes.toSeq*).inheritsFromTypeFullName.toSet - (parentTypes ++ baseTypeFullNames).filterNot(_.matches("(?i)(any|object)")) - else - super.getFieldParents(fa) + override def visitIdentifierAssignedToFieldLoad(i: Identifier, fa: FieldAccess): Set[String] = + val fieldParents = getFieldParents(fa) + fa.astChildren.l match + case List(base: Identifier, fi: FieldIdentifier) + if base.name.equals("self") && fieldParents.nonEmpty => + val referencedFields = cpg.typeDecl.fullNameExact( + fieldParents.toSeq* + ).member.nameExact(fi.canonicalName) + val globalTypes = + referencedFields.flatMap(m => + m.typeFullName +: m.dynamicTypeHintFullName + ).filterNot(_ == Constants.ANY).toSet + associateTypes(i, globalTypes) + case _ => super.visitIdentifierAssignedToFieldLoad(i, fa) - override def getLiteralType(l: Literal): Set[String] = - (l.code match - case code if code.toIntOption.isDefined => Some(s"${PythonAstVisitor.builtinPrefix}int") - case code if code.toDoubleOption.isDefined => - Some(s"${PythonAstVisitor.builtinPrefix}float") - case code if "True".equals(code) || "False".equals(code) => - Some(s"${PythonAstVisitor.builtinPrefix}bool") - case code if code.equals("None") => Some(s"${PythonAstVisitor.builtinPrefix}None") - case code if isPyString(code) => Some(s"${PythonAstVisitor.builtinPrefix}str") - case _ => None - ).toSet + override def getFieldParents(fa: FieldAccess): Set[String] = + if fa.method.name == "" then + Set(fa.method.fullName) + else if fa.method.typeDecl.nonEmpty then + val parentTypes = fa.method.typeDecl.fullName.toSet + val baseTypeFullNames = + cpg.typeDecl.fullNameExact(parentTypes.toSeq*).inheritsFromTypeFullName.toSet + (parentTypes ++ baseTypeFullNames).filterNot(_.matches("(?i)(any|object)")) + else + super.getFieldParents(fa) - private def isPyString(s: String): Boolean = - (s.startsWith("\"") || s.startsWith("'")) && (s.endsWith("\"") || s.endsWith("'")) + override def getLiteralType(l: Literal): Set[String] = + (l.code match + case code if code.toIntOption.isDefined => Some(s"${PythonAstVisitor.builtinPrefix}int") + case code if code.toDoubleOption.isDefined => + Some(s"${PythonAstVisitor.builtinPrefix}float") + case code if "True".equals(code) || "False".equals(code) => + Some(s"${PythonAstVisitor.builtinPrefix}bool") + case code if code.equals("None") => Some(s"${PythonAstVisitor.builtinPrefix}None") + case code if isPyString(code) => Some(s"${PythonAstVisitor.builtinPrefix}str") + case _ => None + ).toSet - override def createCallFromIdentifierTypeFullName( - typeFullName: String, - callName: String - ): String = - lazy val tName = typeFullName.split("\\.").lastOption.getOrElse(typeFullName) - typeFullName match - case t if t.matches(".*(<\\w+>)$") => - super.createCallFromIdentifierTypeFullName(typeFullName, callName) - case t if t.matches(".*\\.<(member|returnValue|indexAccess)>(\\(.*\\))?") => - super.createCallFromIdentifierTypeFullName(typeFullName, callName) - case t if isConstructor(tName) => - Seq(t, callName).mkString(pathSep.toString) - case _ => super.createCallFromIdentifierTypeFullName(typeFullName, callName) + private def isPyString(s: String): Boolean = + (s.startsWith("\"") || s.startsWith("'")) && (s.endsWith("\"") || s.endsWith("'")) - override protected def isConstructor(name: String): Boolean = - name.nonEmpty && name.charAt(0).isUpper + override def createCallFromIdentifierTypeFullName( + typeFullName: String, + callName: String + ): String = + lazy val tName = typeFullName.split("\\.").lastOption.getOrElse(typeFullName) + typeFullName match + case t if t.matches(".*(<\\w+>)$") => + super.createCallFromIdentifierTypeFullName(typeFullName, callName) + case t if t.matches(".*\\.<(member|returnValue|indexAccess)>(\\(.*\\))?") => + super.createCallFromIdentifierTypeFullName(typeFullName, callName) + case t if isConstructor(tName) => + Seq(t, callName).mkString(pathSep.toString) + case _ => super.createCallFromIdentifierTypeFullName(typeFullName, callName) - override def prepopulateSymbolTable(): Unit = - cu.ast.isMethodRef.where( - _.astSiblings.isIdentifier.nameExact("classmethod") - ).referencedMethod.foreach { - classMethod => - classMethod.parameter - .nameExact("cls") - .foreach { cls => - val clsPath = classMethod.typeDecl.fullName.toSet - symbolTable.put(LocalVar(cls.name), clsPath) - if cls.typeFullName == "ANY" then - builder.setNodeProperty( - cls, - PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, - clsPath.toSeq - ) - } - } - super.prepopulateSymbolTable() - end prepopulateSymbolTable + override protected def isConstructor(name: String): Boolean = + name.nonEmpty && name.charAt(0).isUpper - override protected def postSetTypeInformation(): Unit = - cu.typeDecl - .map(t => - t -> t.inheritsFromTypeFullName.partition(itf => - symbolTable.contains(LocalVar(itf)) - ) - ) - .foreach { case (t, (identifierTypes, otherTypes)) => - val existingTypes = (identifierTypes ++ otherTypes).distinct - val resolvedTypes = identifierTypes.map(LocalVar.apply).flatMap(symbolTable.get) - if existingTypes != resolvedTypes && resolvedTypes.nonEmpty then - state.changesWereMade.compareAndExchange(false, true) + override def prepopulateSymbolTable(): Unit = + cu.ast.isMethodRef.where( + _.astSiblings.isIdentifier.nameExact("classmethod") + ).referencedMethod.foreach { + classMethod => + classMethod.parameter + .nameExact("cls") + .foreach { cls => + val clsPath = classMethod.typeDecl.fullName.toSet + symbolTable.put(LocalVar(cls.name), clsPath) + if cls.typeFullName == "ANY" then builder.setNodeProperty( - t, - PropertyNames.INHERITS_FROM_TYPE_FULL_NAME, - resolvedTypes + cls, + PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, + clsPath.toSeq ) - } + } + } + super.prepopulateSymbolTable() + end prepopulateSymbolTable + + override protected def postSetTypeInformation(): Unit = + cu.typeDecl + .map(t => + t -> t.inheritsFromTypeFullName.partition(itf => + symbolTable.contains(LocalVar(itf)) + ) + ) + .foreach { case (t, (identifierTypes, otherTypes)) => + val existingTypes = (identifierTypes ++ otherTypes).distinct + val resolvedTypes = identifierTypes.map(LocalVar.apply).flatMap(symbolTable.get) + if existingTypes != resolvedTypes && resolvedTypes.nonEmpty then + state.changesWereMade.compareAndExchange(false, true) + builder.setNodeProperty( + t, + PropertyNames.INHERITS_FROM_TYPE_FULL_NAME, + resolvedTypes + ) + } - override protected def visitIdentifierAssignedToTypeRef( - i: Identifier, - t: TypeRef, - rec: Option[String] - ): Set[String] = - t.typ.referencedTypeDecl - .map(_.fullName.stripSuffix("")) - .map(td => symbolTable.append(CallAlias(i.name, rec), Set(td))) - .headOption - .getOrElse(super.visitIdentifierAssignedToTypeRef(i, t, rec)) + override protected def visitIdentifierAssignedToTypeRef( + i: Identifier, + t: TypeRef, + rec: Option[String] + ): Set[String] = + t.typ.referencedTypeDecl + .map(_.fullName.stripSuffix("")) + .map(td => symbolTable.append(CallAlias(i.name, rec), Set(td))) + .headOption + .getOrElse(super.visitIdentifierAssignedToTypeRef(i, t, rec)) - override protected def getIndexAccessTypes(ia: Call): Set[String] = - ia.argument.argumentIndex(1).isCall.headOption match - case Some(c) => - getTypesFromCall(c).map(x => s"$x$pathSep${XTypeRecovery.DummyIndexAccess}") - case _ => super.getIndexAccessTypes(ia) + override protected def getIndexAccessTypes(ia: Call): Set[String] = + ia.argument.argumentIndex(1).isCall.headOption match + case Some(c) => + getTypesFromCall(c).map(x => s"$x$pathSep${XTypeRecovery.DummyIndexAccess}") + case _ => super.getIndexAccessTypes(ia) - override def getTypesFromCall(c: Call): Set[String] = c.name match - case ".listLiteral" => Set(s"${PythonAstVisitor.builtinPrefix}list") - case ".tupleLiteral" => Set(s"${PythonAstVisitor.builtinPrefix}tuple") - case ".dictLiteral" => Set(s"${PythonAstVisitor.builtinPrefix}dict") - case ".setLiteral" => Set(s"${PythonAstVisitor.builtinPrefix}set") - case _ => super.getTypesFromCall(c) + override def getTypesFromCall(c: Call): Set[String] = c.name match + case ".listLiteral" => Set(s"${PythonAstVisitor.builtinPrefix}list") + case ".tupleLiteral" => Set(s"${PythonAstVisitor.builtinPrefix}tuple") + case ".dictLiteral" => Set(s"${PythonAstVisitor.builtinPrefix}dict") + case ".setLiteral" => Set(s"${PythonAstVisitor.builtinPrefix}set") + case _ => super.getTypesFromCall(c) - /** Replaces the `this` prefix with the Pythonic `self` prefix for instance methods of functions - * local to this compilation unit. - */ - private def fromNodeToLocalPythonKey(node: AstNode): Option[LocalKey] = - node match - case n: Method => Option(CallAlias(n.name, Option("self"))) - case _ => SBKey.fromNodeToLocalKey(node) + /** Replaces the `this` prefix with the Pythonic `self` prefix for instance methods of functions + * local to this compilation unit. + */ + private def fromNodeToLocalPythonKey(node: AstNode): Option[LocalKey] = + node match + case n: Method => Option(CallAlias(n.name, Option("self"))) + case _ => SBKey.fromNodeToLocalKey(node) end RecoverForPythonFile diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/memop/AstNodeToMemoryOperationMap.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/memop/AstNodeToMemoryOperationMap.scala index b8619d25..26db19bd 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/memop/AstNodeToMemoryOperationMap.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/memop/AstNodeToMemoryOperationMap.scala @@ -4,24 +4,24 @@ import io.appthreat.pythonparser.ast import scala.collection.mutable class AstNodeToMemoryOperationMap: - private class IdentityHashWrapper(private val astNode: ast.iast): - override def equals(o: Any): Boolean = - o match - case wrapper: IdentityHashWrapper => - astNode eq wrapper.astNode - case _ => - false + private class IdentityHashWrapper(private val astNode: ast.iast): + override def equals(o: Any): Boolean = + o match + case wrapper: IdentityHashWrapper => + astNode eq wrapper.astNode + case _ => + false - override def hashCode(): Int = - System.identityHashCode(astNode) + override def hashCode(): Int = + System.identityHashCode(astNode) - override def toString: String = astNode.toString + override def toString: String = astNode.toString - private val astNodeToMemOp = mutable.HashMap.empty[IdentityHashWrapper, MemoryOperation] + private val astNodeToMemOp = mutable.HashMap.empty[IdentityHashWrapper, MemoryOperation] - def put(astNode: ast.iast, memOp: MemoryOperation): Unit = - astNodeToMemOp.put(new IdentityHashWrapper(astNode), memOp) + def put(astNode: ast.iast, memOp: MemoryOperation): Unit = + astNodeToMemOp.put(new IdentityHashWrapper(astNode), memOp) - def get(astNode: ast.iast): Option[MemoryOperation] = - astNodeToMemOp.get(new IdentityHashWrapper(astNode)) + def get(astNode: ast.iast): Option[MemoryOperation] = + astNodeToMemOp.get(new IdentityHashWrapper(astNode)) end AstNodeToMemoryOperationMap diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/memop/MemoryOperation.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/memop/MemoryOperation.scala index 89665bf0..d8e13c07 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/memop/MemoryOperation.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/memop/MemoryOperation.scala @@ -1,7 +1,7 @@ package io.appthreat.pysrc2cpg.memop sealed trait MemoryOperation: - override def toString: String = getClass.getSimpleName + override def toString: String = getClass.getSimpleName object Store extends MemoryOperation object Load extends MemoryOperation object Del extends MemoryOperation diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/memop/MemoryOperationCalculator.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/memop/MemoryOperationCalculator.scala index 9306312c..6a7a1c36 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/memop/MemoryOperationCalculator.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pysrc2cpg/memop/MemoryOperationCalculator.scala @@ -73,458 +73,458 @@ import io.appthreat.pythonparser.ast.{ import scala.collection.mutable class MemoryOperationCalculator extends AstVisitor[Unit]: - private val stack = mutable.Stack.empty[MemoryOperation] - val astNodeToMemOp = new AstNodeToMemoryOperationMap() - val names = mutable.Set.empty[String] - - private def accept(astNode: iast): Unit = - astNode.accept(this) - - private def accept(astNodes: Iterable[iast]): Unit = - astNodes.foreach(accept) - - private def push(memOp: MemoryOperation): Unit = - stack.push(memOp) - - private def pop(): Unit = - stack.pop() - - override def visit(astNode: iast): Unit = ??? - - override def visit(mod: imod): Unit = ??? - - override def visit(module: Module): Unit = - accept(module.stmts) - - override def visit(stmt: istmt): Unit = ??? - - override def visit(functionDef: FunctionDef): Unit = - push(Load) - accept(functionDef.decorator_list) - accept(functionDef.args) - accept(functionDef.returns) - pop() - accept(functionDef.body) - - override def visit(functionDef: AsyncFunctionDef): Unit = - push(Load) - accept(functionDef.decorator_list) - accept(functionDef.args) - accept(functionDef.returns) - pop() - accept(functionDef.body) - - override def visit(classDef: ClassDef): Unit = - push(Load) - accept(classDef.decorator_list) - accept(classDef.bases) - accept(classDef.keywords) - pop() - accept(classDef.body) - - override def visit(ret: Return): Unit = - push(Load) - accept(ret.value) - pop() - - override def visit(delete: Delete): Unit = - push(Del) - accept(delete.targets) - pop() - - override def visit(assign: Assign): Unit = - push(Store) - accept(assign.targets) - pop() - push(Load) - accept(assign.value) - pop() - - override def visit(annAssign: AnnAssign): Unit = - push(Store) - accept(annAssign.target) - pop() - push(Load) - accept(annAssign.annotation) - accept(annAssign.value) - pop() - - override def visit(augAssign: AugAssign): Unit = - push(Store) - accept(augAssign.target) - pop() - push(Load) - accept(augAssign.value) - pop() - - override def visit(forStmt: For): Unit = - push(Store) - accept(forStmt.target) - pop() - push(Load) - accept(forStmt.iter) - pop() - accept(forStmt.body) - accept(forStmt.orelse) - - override def visit(forStmt: AsyncFor): Unit = - push(Store) - accept(forStmt.target) - pop() - push(Load) - accept(forStmt.iter) - pop() - accept(forStmt.body) - accept(forStmt.orelse) - - override def visit(whileStmt: While): Unit = - push(Load) - accept(whileStmt.test) - pop() - accept(whileStmt.body) - accept(whileStmt.orelse) - - override def visit(ifStmt: If): Unit = - push(Load) - accept(ifStmt.test) - pop() - accept(ifStmt.body) - accept(ifStmt.orelse) - - override def visit(withStmt: With): Unit = - accept(withStmt.items) - accept(withStmt.body) - - override def visit(withStmt: AsyncWith): Unit = - accept(withStmt.items) - accept(withStmt.body) - - override def visit(matchStmt: Match): Unit = - push(Load) - accept(matchStmt.subject) - accept(matchStmt.cases) - pop() - - override def visit(raise: Raise): Unit = - push(Load) - accept(raise.exc) - accept(raise.cause) - pop() - - override def visit(tryStmt: Try): Unit = - accept(tryStmt.body) - accept(tryStmt.handlers) - accept(tryStmt.orelse) - accept(tryStmt.finalbody) - - override def visit(assert: Assert): Unit = - push(Load) - accept(assert.test) - accept(assert.msg) - pop() - - override def visit(importStmt: Import): Unit = {} - - override def visit(importFrom: ImportFrom): Unit = {} - - override def visit(global: Global): Unit = {} - - override def visit(nonlocal: Nonlocal): Unit = {} - - override def visit(expr: Expr): Unit = - push(Load) - expr.value.accept(this) - pop() - - override def visit(pass: Pass): Unit = {} - - override def visit(break: Break): Unit = {} - - override def visit(continue: Continue): Unit = {} - - override def visit(raise: RaiseP2): Unit = - push(Load) - accept(raise.typ) - accept(raise.inst) - accept(raise.tback) - pop() - - override def visit(errorStatement: ErrorStatement): Unit = {} - - override def visit(expr: iexpr): Unit = ??? - - override def visit(boolOp: BoolOp): Unit = - accept(boolOp.values) - - override def visit(namedExpr: NamedExpr): Unit = - push(Store) - accept(namedExpr.target) - pop() - accept(namedExpr.value) - - override def visit(binOp: BinOp): Unit = - accept(binOp.left) - accept(binOp.right) - - override def visit(unaryOp: UnaryOp): Unit = - accept(unaryOp.operand) - - override def visit(lambda: Lambda): Unit = - push(Load) - accept(lambda.args) - pop() - accept(lambda.body) - - override def visit(ifExp: IfExp): Unit = - accept(ifExp.test) - accept(ifExp.body) - accept(ifExp.orelse) + private val stack = mutable.Stack.empty[MemoryOperation] + val astNodeToMemOp = new AstNodeToMemoryOperationMap() + val names = mutable.Set.empty[String] + + private def accept(astNode: iast): Unit = + astNode.accept(this) + + private def accept(astNodes: Iterable[iast]): Unit = + astNodes.foreach(accept) + + private def push(memOp: MemoryOperation): Unit = + stack.push(memOp) + + private def pop(): Unit = + stack.pop() + + override def visit(astNode: iast): Unit = ??? + + override def visit(mod: imod): Unit = ??? + + override def visit(module: Module): Unit = + accept(module.stmts) + + override def visit(stmt: istmt): Unit = ??? + + override def visit(functionDef: FunctionDef): Unit = + push(Load) + accept(functionDef.decorator_list) + accept(functionDef.args) + accept(functionDef.returns) + pop() + accept(functionDef.body) + + override def visit(functionDef: AsyncFunctionDef): Unit = + push(Load) + accept(functionDef.decorator_list) + accept(functionDef.args) + accept(functionDef.returns) + pop() + accept(functionDef.body) + + override def visit(classDef: ClassDef): Unit = + push(Load) + accept(classDef.decorator_list) + accept(classDef.bases) + accept(classDef.keywords) + pop() + accept(classDef.body) + + override def visit(ret: Return): Unit = + push(Load) + accept(ret.value) + pop() + + override def visit(delete: Delete): Unit = + push(Del) + accept(delete.targets) + pop() + + override def visit(assign: Assign): Unit = + push(Store) + accept(assign.targets) + pop() + push(Load) + accept(assign.value) + pop() + + override def visit(annAssign: AnnAssign): Unit = + push(Store) + accept(annAssign.target) + pop() + push(Load) + accept(annAssign.annotation) + accept(annAssign.value) + pop() + + override def visit(augAssign: AugAssign): Unit = + push(Store) + accept(augAssign.target) + pop() + push(Load) + accept(augAssign.value) + pop() + + override def visit(forStmt: For): Unit = + push(Store) + accept(forStmt.target) + pop() + push(Load) + accept(forStmt.iter) + pop() + accept(forStmt.body) + accept(forStmt.orelse) + + override def visit(forStmt: AsyncFor): Unit = + push(Store) + accept(forStmt.target) + pop() + push(Load) + accept(forStmt.iter) + pop() + accept(forStmt.body) + accept(forStmt.orelse) + + override def visit(whileStmt: While): Unit = + push(Load) + accept(whileStmt.test) + pop() + accept(whileStmt.body) + accept(whileStmt.orelse) + + override def visit(ifStmt: If): Unit = + push(Load) + accept(ifStmt.test) + pop() + accept(ifStmt.body) + accept(ifStmt.orelse) + + override def visit(withStmt: With): Unit = + accept(withStmt.items) + accept(withStmt.body) + + override def visit(withStmt: AsyncWith): Unit = + accept(withStmt.items) + accept(withStmt.body) + + override def visit(matchStmt: Match): Unit = + push(Load) + accept(matchStmt.subject) + accept(matchStmt.cases) + pop() + + override def visit(raise: Raise): Unit = + push(Load) + accept(raise.exc) + accept(raise.cause) + pop() + + override def visit(tryStmt: Try): Unit = + accept(tryStmt.body) + accept(tryStmt.handlers) + accept(tryStmt.orelse) + accept(tryStmt.finalbody) + + override def visit(assert: Assert): Unit = + push(Load) + accept(assert.test) + accept(assert.msg) + pop() + + override def visit(importStmt: Import): Unit = {} + + override def visit(importFrom: ImportFrom): Unit = {} + + override def visit(global: Global): Unit = {} + + override def visit(nonlocal: Nonlocal): Unit = {} + + override def visit(expr: Expr): Unit = + push(Load) + expr.value.accept(this) + pop() + + override def visit(pass: Pass): Unit = {} + + override def visit(break: Break): Unit = {} + + override def visit(continue: Continue): Unit = {} + + override def visit(raise: RaiseP2): Unit = + push(Load) + accept(raise.typ) + accept(raise.inst) + accept(raise.tback) + pop() + + override def visit(errorStatement: ErrorStatement): Unit = {} + + override def visit(expr: iexpr): Unit = ??? + + override def visit(boolOp: BoolOp): Unit = + accept(boolOp.values) + + override def visit(namedExpr: NamedExpr): Unit = + push(Store) + accept(namedExpr.target) + pop() + accept(namedExpr.value) + + override def visit(binOp: BinOp): Unit = + accept(binOp.left) + accept(binOp.right) + + override def visit(unaryOp: UnaryOp): Unit = + accept(unaryOp.operand) + + override def visit(lambda: Lambda): Unit = + push(Load) + accept(lambda.args) + pop() + accept(lambda.body) + + override def visit(ifExp: IfExp): Unit = + accept(ifExp.test) + accept(ifExp.body) + accept(ifExp.orelse) - override def visit(dict: Dict): Unit = - accept(dict.keys.collect { case Some(key) => key }) - accept(dict.values) + override def visit(dict: Dict): Unit = + accept(dict.keys.collect { case Some(key) => key }) + accept(dict.values) - override def visit(set: ast.Set): Unit = - accept(set.elts) + override def visit(set: ast.Set): Unit = + accept(set.elts) - override def visit(listComp: ast.ListComp): Unit = - accept(listComp.elt) - accept(listComp.generators) + override def visit(listComp: ast.ListComp): Unit = + accept(listComp.elt) + accept(listComp.generators) - override def visit(setComp: ast.SetComp): Unit = - accept(setComp.elt) - accept(setComp.generators) + override def visit(setComp: ast.SetComp): Unit = + accept(setComp.elt) + accept(setComp.generators) - override def visit(dictComp: ast.DictComp): Unit = - accept(dictComp.key) - accept(dictComp.value) - accept(dictComp.generators) + override def visit(dictComp: ast.DictComp): Unit = + accept(dictComp.key) + accept(dictComp.value) + accept(dictComp.generators) - override def visit(generatorExp: ast.GeneratorExp): Unit = - accept(generatorExp.elt) - accept(generatorExp.generators) + override def visit(generatorExp: ast.GeneratorExp): Unit = + accept(generatorExp.elt) + accept(generatorExp.generators) - override def visit(await: ast.Await): Unit = - accept(await.value) + override def visit(await: ast.Await): Unit = + accept(await.value) - override def visit(yieldExpr: ast.Yield): Unit = - accept(yieldExpr.value) + override def visit(yieldExpr: ast.Yield): Unit = + accept(yieldExpr.value) - override def visit(yieldFrom: ast.YieldFrom): Unit = - accept(yieldFrom.value) + override def visit(yieldFrom: ast.YieldFrom): Unit = + accept(yieldFrom.value) - override def visit(compare: ast.Compare): Unit = - accept(compare.left) - accept(compare.comparators) + override def visit(compare: ast.Compare): Unit = + accept(compare.left) + accept(compare.comparators) - override def visit(call: ast.Call): Unit = - assert(stack.head == Load) - accept(call.func) - accept(call.args) - accept(call.keywords) + override def visit(call: ast.Call): Unit = + assert(stack.head == Load) + accept(call.func) + accept(call.args) + accept(call.keywords) - override def visit(formattedValue: FormattedValue): Unit = - assert(stack.head == Load) - accept(formattedValue.value) + override def visit(formattedValue: FormattedValue): Unit = + assert(stack.head == Load) + accept(formattedValue.value) - override def visit(joinedString: JoinedString): Unit = - assert(stack.head == Load) - accept(joinedString.values) + override def visit(joinedString: JoinedString): Unit = + assert(stack.head == Load) + accept(joinedString.values) - override def visit(constant: ast.Constant): Unit = {} + override def visit(constant: ast.Constant): Unit = {} - override def visit(attribute: ast.Attribute): Unit = - push(Load) - accept(attribute.value) - pop() - astNodeToMemOp.put(attribute, stack.head) + override def visit(attribute: ast.Attribute): Unit = + push(Load) + accept(attribute.value) + pop() + astNodeToMemOp.put(attribute, stack.head) - override def visit(subscript: ast.Subscript): Unit = - push(Load) - accept(subscript.value) - accept(subscript.slice) - pop() - astNodeToMemOp.put(subscript, stack.head) + override def visit(subscript: ast.Subscript): Unit = + push(Load) + accept(subscript.value) + accept(subscript.slice) + pop() + astNodeToMemOp.put(subscript, stack.head) - override def visit(starred: ast.Starred): Unit = - accept(starred.value) - astNodeToMemOp.put(starred, stack.head) + override def visit(starred: ast.Starred): Unit = + accept(starred.value) + astNodeToMemOp.put(starred, stack.head) - override def visit(name: ast.Name): Unit = - astNodeToMemOp.put(name, stack.head) - names.add(name.id) + override def visit(name: ast.Name): Unit = + astNodeToMemOp.put(name, stack.head) + names.add(name.id) - override def visit(list: ast.List): Unit = - accept(list.elts) - astNodeToMemOp.put(list, stack.head) + override def visit(list: ast.List): Unit = + accept(list.elts) + astNodeToMemOp.put(list, stack.head) - override def visit(tuple: ast.Tuple): Unit = - accept(tuple.elts) - astNodeToMemOp.put(tuple, stack.head) + override def visit(tuple: ast.Tuple): Unit = + accept(tuple.elts) + astNodeToMemOp.put(tuple, stack.head) - override def visit(slice: ast.Slice): Unit = - push(Load) - accept(slice.lower) - accept(slice.upper) - accept(slice.step) - pop() + override def visit(slice: ast.Slice): Unit = + push(Load) + accept(slice.lower) + accept(slice.upper) + accept(slice.step) + pop() - override def visit(stringExpList: ast.StringExpList): Unit = - accept(stringExpList.elts) + override def visit(stringExpList: ast.StringExpList): Unit = + accept(stringExpList.elts) - override def visit(boolop: ast.iboolop): Unit = {} + override def visit(boolop: ast.iboolop): Unit = {} - override def visit(and: ast.And.type): Unit = {} + override def visit(and: ast.And.type): Unit = {} - override def visit(or: ast.Or.type): Unit = {} + override def visit(or: ast.Or.type): Unit = {} - override def visit(operator: ast.ioperator): Unit = {} + override def visit(operator: ast.ioperator): Unit = {} - override def visit(add: ast.Add.type): Unit = {} + override def visit(add: ast.Add.type): Unit = {} - override def visit(sub: ast.Sub.type): Unit = {} + override def visit(sub: ast.Sub.type): Unit = {} - override def visit(mult: ast.Mult.type): Unit = {} + override def visit(mult: ast.Mult.type): Unit = {} - override def visit(matMult: ast.MatMult.type): Unit = {} + override def visit(matMult: ast.MatMult.type): Unit = {} - override def visit(div: ast.Div.type): Unit = {} + override def visit(div: ast.Div.type): Unit = {} - override def visit(mod: ast.Mod.type): Unit = {} + override def visit(mod: ast.Mod.type): Unit = {} - override def visit(pow: ast.Pow.type): Unit = {} + override def visit(pow: ast.Pow.type): Unit = {} - override def visit(lShift: ast.LShift.type): Unit = {} + override def visit(lShift: ast.LShift.type): Unit = {} - override def visit(rShift: ast.RShift.type): Unit = {} + override def visit(rShift: ast.RShift.type): Unit = {} - override def visit(bitOr: ast.BitOr.type): Unit = {} + override def visit(bitOr: ast.BitOr.type): Unit = {} - override def visit(bitXor: ast.BitXor.type): Unit = {} + override def visit(bitXor: ast.BitXor.type): Unit = {} - override def visit(bitAnd: ast.BitAnd.type): Unit = {} + override def visit(bitAnd: ast.BitAnd.type): Unit = {} - override def visit(floorDiv: ast.FloorDiv.type): Unit = {} + override def visit(floorDiv: ast.FloorDiv.type): Unit = {} - override def visit(unaryop: ast.iunaryop): Unit = {} + override def visit(unaryop: ast.iunaryop): Unit = {} - override def visit(invert: ast.Invert.type): Unit = {} + override def visit(invert: ast.Invert.type): Unit = {} - override def visit(not: ast.Not.type): Unit = {} + override def visit(not: ast.Not.type): Unit = {} - override def visit(uAdd: ast.UAdd.type): Unit = {} + override def visit(uAdd: ast.UAdd.type): Unit = {} - override def visit(uSub: ast.USub.type): Unit = {} + override def visit(uSub: ast.USub.type): Unit = {} - override def visit(compop: ast.icompop): Unit = {} + override def visit(compop: ast.icompop): Unit = {} - override def visit(eq: ast.Eq.type): Unit = {} + override def visit(eq: ast.Eq.type): Unit = {} - override def visit(notEq: ast.NotEq.type): Unit = {} + override def visit(notEq: ast.NotEq.type): Unit = {} - override def visit(lt: ast.Lt.type): Unit = {} + override def visit(lt: ast.Lt.type): Unit = {} - override def visit(ltE: ast.LtE.type): Unit = {} + override def visit(ltE: ast.LtE.type): Unit = {} - override def visit(gt: ast.Gt.type): Unit = {} + override def visit(gt: ast.Gt.type): Unit = {} - override def visit(gtE: ast.GtE.type): Unit = {} + override def visit(gtE: ast.GtE.type): Unit = {} - override def visit(is: ast.Is.type): Unit = {} + override def visit(is: ast.Is.type): Unit = {} - override def visit(isNot: ast.IsNot.type): Unit = {} + override def visit(isNot: ast.IsNot.type): Unit = {} - override def visit(in: ast.In.type): Unit = {} + override def visit(in: ast.In.type): Unit = {} - override def visit(notIn: ast.NotIn.type): Unit = {} + override def visit(notIn: ast.NotIn.type): Unit = {} - override def visit(comprehension: ast.Comprehension): Unit = - assert(stack.head == Load) - push(Store) - accept(comprehension.target) - pop() - accept(comprehension.iter) - accept(comprehension.ifs) + override def visit(comprehension: ast.Comprehension): Unit = + assert(stack.head == Load) + push(Store) + accept(comprehension.target) + pop() + accept(comprehension.iter) + accept(comprehension.ifs) - override def visit(exceptHandler: ast.ExceptHandler): Unit = - push(Load) - accept(exceptHandler.typ) - pop() - accept(exceptHandler.body) + override def visit(exceptHandler: ast.ExceptHandler): Unit = + push(Load) + accept(exceptHandler.typ) + pop() + accept(exceptHandler.body) - override def visit(arguments: ast.Arguments): Unit = - accept(arguments.posonlyargs) - accept(arguments.args) - accept(arguments.vararg) - accept(arguments.kwonlyargs) - accept(arguments.kw_defaults.collect { case Some(default) => default }) - accept(arguments.kw_arg) - accept(arguments.defaults) + override def visit(arguments: ast.Arguments): Unit = + accept(arguments.posonlyargs) + accept(arguments.args) + accept(arguments.vararg) + accept(arguments.kwonlyargs) + accept(arguments.kw_defaults.collect { case Some(default) => default }) + accept(arguments.kw_arg) + accept(arguments.defaults) - override def visit(arg: ast.Arg): Unit = - accept(arg.annotation) + override def visit(arg: ast.Arg): Unit = + accept(arg.annotation) - override def visit(constant: ast.iconstant): Unit = ??? + override def visit(constant: ast.iconstant): Unit = ??? - override def visit(stringConstant: ast.StringConstant): Unit = {} + override def visit(stringConstant: ast.StringConstant): Unit = {} - override def visit(joinedStringConstant: JoinedStringConstant): Unit = {} + override def visit(joinedStringConstant: JoinedStringConstant): Unit = {} - override def visit(boolConstant: ast.BoolConstant): Unit = {} + override def visit(boolConstant: ast.BoolConstant): Unit = {} - override def visit(intConstant: ast.IntConstant): Unit = {} + override def visit(intConstant: ast.IntConstant): Unit = {} - override def visit(intConstant: ast.FloatConstant): Unit = {} + override def visit(intConstant: ast.FloatConstant): Unit = {} - override def visit(imaginaryConstant: ast.ImaginaryConstant): Unit = {} + override def visit(imaginaryConstant: ast.ImaginaryConstant): Unit = {} - override def visit(noneConstant: ast.NoneConstant.type): Unit = {} + override def visit(noneConstant: ast.NoneConstant.type): Unit = {} - override def visit(ellipsisConstant: ast.EllipsisConstant.type): Unit = {} + override def visit(ellipsisConstant: ast.EllipsisConstant.type): Unit = {} - override def visit(keyword: ast.Keyword): Unit = - assert(stack.head == Load) - accept(keyword.value) + override def visit(keyword: ast.Keyword): Unit = + assert(stack.head == Load) + accept(keyword.value) - override def visit(alias: ast.Alias): Unit = {} + override def visit(alias: ast.Alias): Unit = {} - override def visit(withItem: ast.Withitem): Unit = - push(Load) - accept(withItem.context_expr) - pop() - push(Store) - accept(withItem.optional_vars) - pop() + override def visit(withItem: ast.Withitem): Unit = + push(Load) + accept(withItem.context_expr) + pop() + push(Store) + accept(withItem.optional_vars) + pop() - override def visit(matchCase: MatchCase): Unit = - accept(matchCase.pattern) - accept(matchCase.guard) - accept(matchCase.body) + override def visit(matchCase: MatchCase): Unit = + accept(matchCase.pattern) + accept(matchCase.guard) + accept(matchCase.body) - override def visit(matchValue: MatchValue): Unit = - accept(matchValue.value) + override def visit(matchValue: MatchValue): Unit = + accept(matchValue.value) - override def visit(matchSingleton: MatchSingleton): Unit = {} + override def visit(matchSingleton: MatchSingleton): Unit = {} - override def visit(matchSequence: MatchSequence): Unit = - accept(matchSequence.patterns) + override def visit(matchSequence: MatchSequence): Unit = + accept(matchSequence.patterns) - override def visit(matchMapping: MatchMapping): Unit = - accept(matchMapping.keys) - accept(matchMapping.patterns) + override def visit(matchMapping: MatchMapping): Unit = + accept(matchMapping.keys) + accept(matchMapping.patterns) - override def visit(matchClass: MatchClass): Unit = - accept(matchClass.cls) - accept(matchClass.patterns) - accept(matchClass.kwd_patterns) + override def visit(matchClass: MatchClass): Unit = + accept(matchClass.cls) + accept(matchClass.patterns) + accept(matchClass.kwd_patterns) - override def visit(matchStar: MatchStar): Unit = {} + override def visit(matchStar: MatchStar): Unit = {} - override def visit(matchAs: MatchAs): Unit = - accept(matchAs.pattern) + override def visit(matchAs: MatchAs): Unit = + accept(matchAs.pattern) - override def visit(matchOr: MatchOr): Unit = - accept(matchOr.patterns) + override def visit(matchOr: MatchOr): Unit = + accept(matchOr.patterns) - override def visit(typeIgnore: ast.TypeIgnore): Unit = {} + override def visit(typeIgnore: ast.TypeIgnore): Unit = {} end MemoryOperationCalculator diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/AstPrinter.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/AstPrinter.scala index 53bbf461..c076cf9f 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/AstPrinter.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/AstPrinter.scala @@ -125,624 +125,624 @@ import io.appthreat.pythonparser.ast.* import scala.collection.immutable class AstPrinter(indentStr: String) extends AstVisitor[String]: - private val ls = "\n" - - def print(astNode: iast): String = - astNode.accept(this) - - def printIndented(astNode: iast): String = - val printStr = astNode.accept(this) - - indentStr + printStr.replaceAll(ls, ls + indentStr) - - override def visit(ast: iast): String = ??? + private val ls = "\n" + + def print(astNode: iast): String = + astNode.accept(this) + + def printIndented(astNode: iast): String = + val printStr = astNode.accept(this) + + indentStr + printStr.replaceAll(ls, ls + indentStr) + + override def visit(ast: iast): String = ??? + + override def visit(mod: imod): String = ??? + + override def visit(module: Module): String = + module.stmts.map(print).mkString(ls) + + override def visit(stmt: istmt): String = ??? + + override def visit(functionDef: FunctionDef): String = + functionDef.decorator_list.map(d => "@" + print(d) + ls).mkString("") + + "def " + functionDef.name + "(" + print(functionDef.args) + ")" + + functionDef.returns.map(r => " -> " + print(r)).getOrElse("") + + ":" + functionDef.body.map(printIndented).mkString(ls, ls, "") + + override def visit(functionDef: AsyncFunctionDef): String = + functionDef.decorator_list.map(d => "@" + print(d) + ls).mkString("") + + "async def " + functionDef.name + "(" + print(functionDef.args) + ")" + + functionDef.returns.map(r => " -> " + print(r)).getOrElse("") + + ":" + functionDef.body.map(printIndented).mkString(ls, ls, "") + + override def visit(classDef: ClassDef): String = + val optionArgEndComma = + if classDef.bases.nonEmpty && classDef.keywords.nonEmpty then ", " else "" + + classDef.decorator_list.map(d => "@" + print(d) + ls).mkString("") + + "class " + classDef.name + + "(" + + classDef.bases.map(print).mkString(", ") + + optionArgEndComma + + classDef.keywords.map(print).mkString(", ") + + ")" + ":" + + classDef.body.map(printIndented).mkString(ls, ls, "") + + override def visit(ret: Return): String = + "return" + ret.value.map(v => " " + print(v)).getOrElse("") + + override def visit(delete: Delete): String = + "del " + delete.targets.map(print).mkString(", ") + + override def visit(assign: Assign): String = + assign.targets.map(print).mkString("", " = ", " = ") + print(assign.value) + + override def visit(annAssign: AnnAssign): String = + print(annAssign.target) + + ": " + print(annAssign.annotation) + + annAssign.value.map(v => " = " + print(v)).getOrElse("") + + override def visit(augAssign: AugAssign): String = + print(augAssign.target) + + " " + print(augAssign.op) + "= " + + print(augAssign.value) + + override def visit(forStmt: For): String = + "for " + print(forStmt.target) + " in " + print(forStmt.iter) + ":" + + forStmt.body.map(printIndented).mkString(ls, ls, "") + + (if forStmt.orelse.nonEmpty then + s"${ls}else:" + + forStmt.orelse.map(printIndented).mkString(ls, ls, "") + else "") + + override def visit(forStmt: AsyncFor): String = + "async for " + print(forStmt.target) + " in " + print(forStmt.iter) + ":" + + forStmt.body.map(printIndented).mkString(ls, ls, "") + + (if forStmt.orelse.nonEmpty then + s"${ls}else:" + + forStmt.orelse.map(printIndented).mkString(ls, ls, "") + else "") + + override def visit(whileStmt: While): String = + "while " + print(whileStmt.test) + ":" + + whileStmt.body.map(printIndented).mkString(ls, ls, "") + + (if whileStmt.orelse.nonEmpty then + s"${ls}else:" + + whileStmt.orelse.map(printIndented).mkString(ls, ls, "") + else "") + + override def visit(ifStmt: If): String = + val elseString = + ifStmt.orelse.size match + case 0 => "" + case 1 if ifStmt.orelse.head.isInstanceOf[If] => + s"${ls}el" + print(ifStmt.orelse.head) + case _ => + s"${ls}else:" + + ifStmt.orelse.map(printIndented).mkString(ls, ls, "") + + "if " + print(ifStmt.test) + ":" + + ifStmt.body.map(printIndented).mkString(ls, ls, "") + + elseString + + override def visit(withStmt: With): String = + "with " + withStmt.items.map(print).mkString(", ") + ":" + + withStmt.body.map(printIndented).mkString(ls, ls, "") + + override def visit(withStmt: AsyncWith): String = + "async with " + withStmt.items.map(print).mkString(", ") + ":" + + withStmt.body.map(printIndented).mkString(ls, ls, "") + + override def visit(matchStmt: Match): String = + val subjectSuffix = matchStmt.subject match + case _: Starred => + "," + case _ => + "" + "match " + print(matchStmt.subject) + subjectSuffix + ":" + + matchStmt.cases.map(printIndented).mkString(ls, ls, "") + + override def visit(raise: Raise): String = + "raise" + raise.exc.map(e => " " + print(e)).getOrElse("") + + raise.cause.map(c => " from " + print(c)).getOrElse("") + + override def visit(tryStmt: Try): String = + val elseString = + if tryStmt.orelse.nonEmpty then + s"${ls}else:" + + tryStmt.orelse.map(printIndented).mkString(ls, ls, "") + else + "" - override def visit(mod: imod): String = ??? + val finallyString = + if tryStmt.finalbody.nonEmpty then + s"${ls}finally:" + + tryStmt.finalbody.map(printIndented).mkString(ls, ls, "") + else + "" - override def visit(module: Module): String = - module.stmts.map(print).mkString(ls) + val handlersString = + if tryStmt.handlers.nonEmpty then + tryStmt.handlers.map(print).mkString(ls, ls, "") + else + "" - override def visit(stmt: istmt): String = ??? + "try:" + + tryStmt.body.map(printIndented).mkString(ls, ls, "") + + handlersString + + elseString + + finallyString + end visit - override def visit(functionDef: FunctionDef): String = - functionDef.decorator_list.map(d => "@" + print(d) + ls).mkString("") + - "def " + functionDef.name + "(" + print(functionDef.args) + ")" + - functionDef.returns.map(r => " -> " + print(r)).getOrElse("") + - ":" + functionDef.body.map(printIndented).mkString(ls, ls, "") + override def visit(assert: Assert): String = + "assert " + print(assert.test) + assert.msg.map(m => ", " + print(m)).getOrElse("") - override def visit(functionDef: AsyncFunctionDef): String = - functionDef.decorator_list.map(d => "@" + print(d) + ls).mkString("") + - "async def " + functionDef.name + "(" + print(functionDef.args) + ")" + - functionDef.returns.map(r => " -> " + print(r)).getOrElse("") + - ":" + functionDef.body.map(printIndented).mkString(ls, ls, "") + override def visit(importStmt: Import): String = + "import " + importStmt.names.map(print).mkString(", ") - override def visit(classDef: ClassDef): String = - val optionArgEndComma = - if classDef.bases.nonEmpty && classDef.keywords.nonEmpty then ", " else "" - - classDef.decorator_list.map(d => "@" + print(d) + ls).mkString("") + - "class " + classDef.name + - "(" + - classDef.bases.map(print).mkString(", ") + - optionArgEndComma + - classDef.keywords.map(print).mkString(", ") + - ")" + ":" + - classDef.body.map(printIndented).mkString(ls, ls, "") - - override def visit(ret: Return): String = - "return" + ret.value.map(v => " " + print(v)).getOrElse("") - - override def visit(delete: Delete): String = - "del " + delete.targets.map(print).mkString(", ") - - override def visit(assign: Assign): String = - assign.targets.map(print).mkString("", " = ", " = ") + print(assign.value) - - override def visit(annAssign: AnnAssign): String = - print(annAssign.target) + - ": " + print(annAssign.annotation) + - annAssign.value.map(v => " = " + print(v)).getOrElse("") - - override def visit(augAssign: AugAssign): String = - print(augAssign.target) + - " " + print(augAssign.op) + "= " + - print(augAssign.value) - - override def visit(forStmt: For): String = - "for " + print(forStmt.target) + " in " + print(forStmt.iter) + ":" + - forStmt.body.map(printIndented).mkString(ls, ls, "") + - (if forStmt.orelse.nonEmpty then - s"${ls}else:" + - forStmt.orelse.map(printIndented).mkString(ls, ls, "") - else "") - - override def visit(forStmt: AsyncFor): String = - "async for " + print(forStmt.target) + " in " + print(forStmt.iter) + ":" + - forStmt.body.map(printIndented).mkString(ls, ls, "") + - (if forStmt.orelse.nonEmpty then - s"${ls}else:" + - forStmt.orelse.map(printIndented).mkString(ls, ls, "") - else "") - - override def visit(whileStmt: While): String = - "while " + print(whileStmt.test) + ":" + - whileStmt.body.map(printIndented).mkString(ls, ls, "") + - (if whileStmt.orelse.nonEmpty then - s"${ls}else:" + - whileStmt.orelse.map(printIndented).mkString(ls, ls, "") - else "") - - override def visit(ifStmt: If): String = - val elseString = - ifStmt.orelse.size match - case 0 => "" - case 1 if ifStmt.orelse.head.isInstanceOf[If] => - s"${ls}el" + print(ifStmt.orelse.head) - case _ => - s"${ls}else:" + - ifStmt.orelse.map(printIndented).mkString(ls, ls, "") - - "if " + print(ifStmt.test) + ":" + - ifStmt.body.map(printIndented).mkString(ls, ls, "") + - elseString - - override def visit(withStmt: With): String = - "with " + withStmt.items.map(print).mkString(", ") + ":" + - withStmt.body.map(printIndented).mkString(ls, ls, "") - - override def visit(withStmt: AsyncWith): String = - "async with " + withStmt.items.map(print).mkString(", ") + ":" + - withStmt.body.map(printIndented).mkString(ls, ls, "") - - override def visit(matchStmt: Match): String = - val subjectSuffix = matchStmt.subject match - case _: Starred => - "," - case _ => - "" - "match " + print(matchStmt.subject) + subjectSuffix + ":" + - matchStmt.cases.map(printIndented).mkString(ls, ls, "") - - override def visit(raise: Raise): String = - "raise" + raise.exc.map(e => " " + print(e)).getOrElse("") + - raise.cause.map(c => " from " + print(c)).getOrElse("") - - override def visit(tryStmt: Try): String = - val elseString = - if tryStmt.orelse.nonEmpty then - s"${ls}else:" + - tryStmt.orelse.map(printIndented).mkString(ls, ls, "") - else - "" - - val finallyString = - if tryStmt.finalbody.nonEmpty then - s"${ls}finally:" + - tryStmt.finalbody.map(printIndented).mkString(ls, ls, "") - else - "" - - val handlersString = - if tryStmt.handlers.nonEmpty then - tryStmt.handlers.map(print).mkString(ls, ls, "") - else - "" - - "try:" + - tryStmt.body.map(printIndented).mkString(ls, ls, "") + - handlersString + - elseString + - finallyString - end visit - - override def visit(assert: Assert): String = - "assert " + print(assert.test) + assert.msg.map(m => ", " + print(m)).getOrElse("") - - override def visit(importStmt: Import): String = - "import " + importStmt.names.map(print).mkString(", ") - - override def visit(importFrom: ImportFrom): String = - val relativeImportDots = - if importFrom.level != 0 then - " " + "." * importFrom.level - else - "" - "from" + relativeImportDots + importFrom.module.map(m => " " + m).getOrElse("") + - " import " + importFrom.names.map(print).mkString(", ") - - override def visit(global: Global): String = - "global " + global.names.mkString(", ") - - override def visit(nonlocal: Nonlocal): String = - "nonlocal " + nonlocal.names.mkString(", ") - - override def visit(expr: Expr): String = - print(expr.value) - - override def visit(pass: Pass): String = - "pass" - - override def visit(break: Break): String = - "break" - - override def visit(continue: Continue): String = - "continue" - - override def visit(raise: RaiseP2): String = - "raise" + raise.typ.map(t => " " + print(t)).getOrElse("") + - raise.inst.map(i => ", " + print(i)).getOrElse("") + - raise.tback.map(t => ", " + print(t)).getOrElse("") - - override def visit(errorStmt: ErrorStatement): String = - "" - - override def visit(expr: iexpr): String = ??? - - override def visit(boolOp: BoolOp): String = - val opString = " " + print(boolOp.op) + " " - boolOp.values.map(print).mkString(opString) - - override def visit(namedExpr: NamedExpr): String = - print(namedExpr.target) + " := " + print(namedExpr.value) - - override def visit(binOp: BinOp): String = - print(binOp.left) + " " + print(binOp.op) + " " + print(binOp.right) - - override def visit(unaryOp: UnaryOp): String = - val opString = unaryOp.op match - case Not => - print(unaryOp.op) + " " - case _ => - print(unaryOp.op) - opString + print(unaryOp.operand) - - override def visit(lambda: Lambda): String = - val argStr = print(lambda.args) - if argStr.nonEmpty then - "lambda " + argStr + ": " + print(lambda.body) + override def visit(importFrom: ImportFrom): String = + val relativeImportDots = + if importFrom.level != 0 then + " " + "." * importFrom.level else - "lambda: " + print(lambda.body) - - override def visit(ifExp: IfExp): String = - print(ifExp.body) + " if " + print(ifExp.test) + " else " + print(ifExp.orelse) - - override def visit(dict: Dict): String = - "{" + dict.keys - .zip(dict.values) - .map { case (key, value) => - key match - case Some(k) => - print(k) + ":" + print(value) - case None => - "**" + print(value) - } - .mkString(", ") + "}" - - override def visit(set: ast.Set): String = - "{" + set.elts.map(print).mkString(", ") + "}" - - override def visit(listComp: ListComp): String = - "[" + print(listComp.elt) + listComp.generators.map(print).mkString("") + "]" + "" + "from" + relativeImportDots + importFrom.module.map(m => " " + m).getOrElse("") + + " import " + importFrom.names.map(print).mkString(", ") - override def visit(setComp: SetComp): String = - "{" + print(setComp.elt) + setComp.generators.map(print).mkString("") + "}" + override def visit(global: Global): String = + "global " + global.names.mkString(", ") - override def visit(dictComp: DictComp): String = - "{" + print(dictComp.key) + ":" + print(dictComp.value) + - dictComp.generators.map(print).mkString("") + "}" + override def visit(nonlocal: Nonlocal): String = + "nonlocal " + nonlocal.names.mkString(", ") + + override def visit(expr: Expr): String = + print(expr.value) + + override def visit(pass: Pass): String = + "pass" + + override def visit(break: Break): String = + "break" + + override def visit(continue: Continue): String = + "continue" + + override def visit(raise: RaiseP2): String = + "raise" + raise.typ.map(t => " " + print(t)).getOrElse("") + + raise.inst.map(i => ", " + print(i)).getOrElse("") + + raise.tback.map(t => ", " + print(t)).getOrElse("") + + override def visit(errorStmt: ErrorStatement): String = + "" + + override def visit(expr: iexpr): String = ??? + + override def visit(boolOp: BoolOp): String = + val opString = " " + print(boolOp.op) + " " + boolOp.values.map(print).mkString(opString) + + override def visit(namedExpr: NamedExpr): String = + print(namedExpr.target) + " := " + print(namedExpr.value) + + override def visit(binOp: BinOp): String = + print(binOp.left) + " " + print(binOp.op) + " " + print(binOp.right) + + override def visit(unaryOp: UnaryOp): String = + val opString = unaryOp.op match + case Not => + print(unaryOp.op) + " " + case _ => + print(unaryOp.op) + opString + print(unaryOp.operand) + + override def visit(lambda: Lambda): String = + val argStr = print(lambda.args) + if argStr.nonEmpty then + "lambda " + argStr + ": " + print(lambda.body) + else + "lambda: " + print(lambda.body) + + override def visit(ifExp: IfExp): String = + print(ifExp.body) + " if " + print(ifExp.test) + " else " + print(ifExp.orelse) + + override def visit(dict: Dict): String = + "{" + dict.keys + .zip(dict.values) + .map { case (key, value) => + key match + case Some(k) => + print(k) + ":" + print(value) + case None => + "**" + print(value) + } + .mkString(", ") + "}" + + override def visit(set: ast.Set): String = + "{" + set.elts.map(print).mkString(", ") + "}" + + override def visit(listComp: ListComp): String = + "[" + print(listComp.elt) + listComp.generators.map(print).mkString("") + "]" + + override def visit(setComp: SetComp): String = + "{" + print(setComp.elt) + setComp.generators.map(print).mkString("") + "}" + + override def visit(dictComp: DictComp): String = + "{" + print(dictComp.key) + ":" + print(dictComp.value) + + dictComp.generators.map(print).mkString("") + "}" + + override def visit(generatorExp: GeneratorExp): String = + "(" + print(generatorExp.elt) + generatorExp.generators.map(print).mkString("") + ")" + + override def visit(await: Await): String = + "await " + print(await.value) + + override def visit(yieldExpr: Yield): String = + "yield" + yieldExpr.value.map(v => " " + print(v)).getOrElse("") + + override def visit(yieldFrom: YieldFrom): String = + "yield from " + print(yieldFrom.value) + + override def visit(compare: Compare): String = + print(compare.left) + compare.ops + .zip(compare.comparators) + .map { case (op, comparator) => + " " + print(op) + " " + print(comparator) + } + .mkString("") + + override def visit(call: Call): String = + if call.args.size == 1 && call.args.head.isInstanceOf[GeneratorExp] then + // Special case in order to avoid double parenthesis since GeneratorExp adds a + // set of parenthesis on its own. + print(call.func) + print(call.args.head) + else + val optionArgEndComma = + if call.args.nonEmpty && call.keywords.nonEmpty then ", " else "" + print(call.func) + "(" + call.args.map(print).mkString(", ") + optionArgEndComma + + call.keywords.map(print).mkString(", ") + ")" - override def visit(generatorExp: GeneratorExp): String = - "(" + print(generatorExp.elt) + generatorExp.generators.map(print).mkString("") + ")" + override def visit(formattedValue: FormattedValue): String = + val equalSignStr = if formattedValue.equalSign then "=" else "" + val conversionStr = formattedValue.conversion match + case -1 => "" + case 115 => "!s" + case 114 => "!r" + case 97 => "!a" - override def visit(await: Await): String = - "await " + print(await.value) + val formatSpecStr = formattedValue.format_spec match + case Some(formatSpec) => ":" + formatSpec + case None => "" - override def visit(yieldExpr: Yield): String = - "yield" + yieldExpr.value.map(v => " " + print(v)).getOrElse("") + "{" + print(formattedValue.value) + + equalSignStr + + conversionStr + + formatSpecStr + + "}" - override def visit(yieldFrom: YieldFrom): String = - "yield from " + print(yieldFrom.value) + override def visit(joinedString: JoinedString): String = + joinedString.prefix + joinedString.quote + + joinedString.values.map(print).mkString("") + joinedString.quote - override def visit(compare: Compare): String = - print(compare.left) + compare.ops - .zip(compare.comparators) - .map { case (op, comparator) => - " " + print(op) + " " + print(comparator) - } - .mkString("") + override def visit(constant: Constant): String = + print(constant.value) - override def visit(call: Call): String = - if call.args.size == 1 && call.args.head.isInstanceOf[GeneratorExp] then - // Special case in order to avoid double parenthesis since GeneratorExp adds a - // set of parenthesis on its own. - print(call.func) + print(call.args.head) - else - val optionArgEndComma = - if call.args.nonEmpty && call.keywords.nonEmpty then ", " else "" - print(call.func) + "(" + call.args.map(print).mkString(", ") + optionArgEndComma + - call.keywords.map(print).mkString(", ") + ")" - - override def visit(formattedValue: FormattedValue): String = - val equalSignStr = if formattedValue.equalSign then "=" else "" - val conversionStr = formattedValue.conversion match - case -1 => "" - case 115 => "!s" - case 114 => "!r" - case 97 => "!a" - - val formatSpecStr = formattedValue.format_spec match - case Some(formatSpec) => ":" + formatSpec - case None => "" - - "{" + print(formattedValue.value) + - equalSignStr + - conversionStr + - formatSpecStr + - "}" - - override def visit(joinedString: JoinedString): String = - joinedString.prefix + joinedString.quote + - joinedString.values.map(print).mkString("") + joinedString.quote - - override def visit(constant: Constant): String = - print(constant.value) - - override def visit(attribute: Attribute): String = - print(attribute.value) + "." + attribute.attr - - override def visit(subscript: Subscript): String = - print(subscript.value) + "[" + print(subscript.slice) + "]" - - override def visit(starred: Starred): String = - "*" + print(starred.value) - - override def visit(name: Name): String = - name.id - - override def visit(list: ast.List): String = - "[" + list.elts.map(print).mkString(", ") + "]" - - override def visit(tuple: Tuple): String = - if tuple.elts.size == 1 then - "(" + print(tuple.elts.head) + ",)" - else - "(" + tuple.elts.map(print).mkString(",") + ")" + override def visit(attribute: Attribute): String = + print(attribute.value) + "." + attribute.attr - override def visit(slice: Slice): String = - slice.lower.map(print).getOrElse("") + - ":" + slice.upper.map(print).getOrElse("") + - slice.step.map(expr => ":" + print(expr)).getOrElse("") + override def visit(subscript: Subscript): String = + print(subscript.value) + "[" + print(subscript.slice) + "]" - override def visit(stringExpList: StringExpList): String = - stringExpList.elts.map(print).mkString(" ") + override def visit(starred: Starred): String = + "*" + print(starred.value) - override def visit(alias: Alias): String = - alias.name + alias.asName.map(n => " as " + n).getOrElse("") + override def visit(name: Name): String = + name.id - override def visit(boolop: iboolop): String = ??? + override def visit(list: ast.List): String = + "[" + list.elts.map(print).mkString(", ") + "]" - override def visit(and: And.type): String = - "and" + override def visit(tuple: Tuple): String = + if tuple.elts.size == 1 then + "(" + print(tuple.elts.head) + ",)" + else + "(" + tuple.elts.map(print).mkString(",") + ")" - override def visit(or: Or.type): String = - "or" + override def visit(slice: Slice): String = + slice.lower.map(print).getOrElse("") + + ":" + slice.upper.map(print).getOrElse("") + + slice.step.map(expr => ":" + print(expr)).getOrElse("") - override def visit(compop: icompop): String = ??? + override def visit(stringExpList: StringExpList): String = + stringExpList.elts.map(print).mkString(" ") - override def visit(eq: Eq.type): String = - "==" + override def visit(alias: Alias): String = + alias.name + alias.asName.map(n => " as " + n).getOrElse("") - override def visit(noteq: NotEq.type): String = - "!=" + override def visit(boolop: iboolop): String = ??? - override def visit(lt: Lt.type): String = - "<" + override def visit(and: And.type): String = + "and" - override def visit(ltE: LtE.type): String = - "<=" + override def visit(or: Or.type): String = + "or" - override def visit(gt: Gt.type): String = - ">" + override def visit(compop: icompop): String = ??? - override def visit(gtE: GtE.type): String = - ">=" + override def visit(eq: Eq.type): String = + "==" - override def visit(is: Is.type): String = - "is" + override def visit(noteq: NotEq.type): String = + "!=" - override def visit(isNot: IsNot.type): String = - "is not" + override def visit(lt: Lt.type): String = + "<" - override def visit(in: In.type): String = - "in" + override def visit(ltE: LtE.type): String = + "<=" - override def visit(notIn: NotIn.type): String = - "not in" + override def visit(gt: Gt.type): String = + ">" - override def visit(constant: iconstant): String = ??? + override def visit(gtE: GtE.type): String = + ">=" - override def visit(stringConstant: StringConstant): String = - stringConstant.prefix + stringConstant.quote + stringConstant.value + stringConstant.quote + override def visit(is: Is.type): String = + "is" - override def visit(joinedStringConstant: JoinedStringConstant): String = - joinedStringConstant.value + override def visit(isNot: IsNot.type): String = + "is not" - override def visit(boolConstant: BoolConstant): String = - if boolConstant.value then - "True" - else - "False" + override def visit(in: In.type): String = + "in" - override def visit(intConstant: IntConstant): String = - intConstant.value + override def visit(notIn: NotIn.type): String = + "not in" - override def visit(floatConstant: FloatConstant): String = - floatConstant.value + override def visit(constant: iconstant): String = ??? - override def visit(imaginaryConstant: ImaginaryConstant): String = - imaginaryConstant.value + override def visit(stringConstant: StringConstant): String = + stringConstant.prefix + stringConstant.quote + stringConstant.value + stringConstant.quote - override def visit(noneConstant: NoneConstant.type): String = - "None" + override def visit(joinedStringConstant: JoinedStringConstant): String = + joinedStringConstant.value - override def visit(ellipsisConstant: EllipsisConstant.type): String = - "..." + override def visit(boolConstant: BoolConstant): String = + if boolConstant.value then + "True" + else + "False" - override def visit(exceptHandler: ExceptHandler): String = - "except" + - exceptHandler.typ.map(t => " " + print(t)).getOrElse("") + - exceptHandler.name.map(n => " as " + n).getOrElse("") + - ":" + - exceptHandler.body.map(printIndented).mkString(ls, ls, "") + override def visit(intConstant: IntConstant): String = + intConstant.value - override def visit(keyword: Keyword): String = - keyword.arg match - case Some(argName) => - argName + " = " + print(keyword.value) - case None => - "**" + print(keyword.value) + override def visit(floatConstant: FloatConstant): String = + floatConstant.value - override def visit(operator: ioperator): String = ??? + override def visit(imaginaryConstant: ImaginaryConstant): String = + imaginaryConstant.value - override def visit(add: Add.type): String = - "+" + override def visit(noneConstant: NoneConstant.type): String = + "None" - override def visit(sub: Sub.type): String = - "-" + override def visit(ellipsisConstant: EllipsisConstant.type): String = + "..." - override def visit(mult: Mult.type): String = - "*" + override def visit(exceptHandler: ExceptHandler): String = + "except" + + exceptHandler.typ.map(t => " " + print(t)).getOrElse("") + + exceptHandler.name.map(n => " as " + n).getOrElse("") + + ":" + + exceptHandler.body.map(printIndented).mkString(ls, ls, "") - override def visit(matMult: MatMult.type): String = - "@" + override def visit(keyword: Keyword): String = + keyword.arg match + case Some(argName) => + argName + " = " + print(keyword.value) + case None => + "**" + print(keyword.value) - override def visit(div: Div.type): String = - "/" - - override def visit(mod: Mod.type): String = - "%" + override def visit(operator: ioperator): String = ??? - override def visit(pow: Pow.type): String = - "**" + override def visit(add: Add.type): String = + "+" - override def visit(lShift: LShift.type): String = - "<<" + override def visit(sub: Sub.type): String = + "-" - override def visit(rShift: RShift.type): String = - ">>" + override def visit(mult: Mult.type): String = + "*" - override def visit(bitOr: BitOr.type): String = - "|" + override def visit(matMult: MatMult.type): String = + "@" - override def visit(bitXor: BitXor.type): String = - "^" + override def visit(div: Div.type): String = + "/" - override def visit(bitAnd: BitAnd.type): String = - "&" + override def visit(mod: Mod.type): String = + "%" - override def visit(floorDiv: FloorDiv.type): String = - "//" + override def visit(pow: Pow.type): String = + "**" - override def visit(unaryop: iunaryop): String = ??? + override def visit(lShift: LShift.type): String = + "<<" - override def visit(invert: Invert.type): String = - "~" + override def visit(rShift: RShift.type): String = + ">>" - override def visit(not: Not.type): String = - "not" + override def visit(bitOr: BitOr.type): String = + "|" - override def visit(uAdd: UAdd.type): String = - "+" + override def visit(bitXor: BitXor.type): String = + "^" - override def visit(uSub: USub.type): String = - "-" + override def visit(bitAnd: BitAnd.type): String = + "&" - override def visit(arg: Arg): String = - arg.arg + arg.annotation.map(a => ": " + print(a)).getOrElse("") + override def visit(floorDiv: FloorDiv.type): String = + "//" - override def visit(arguments: Arguments): String = - var result = "" - var separatorString = "" - val combinedPosArgSize = arguments.posonlyargs.size + arguments.args.size - val defaultArgs = immutable.List.fill(combinedPosArgSize - arguments.defaults.size)(None) ++ - arguments.defaults.map(Option.apply) + override def visit(unaryop: iunaryop): String = ??? - if arguments.posonlyargs.nonEmpty then - val posOnlyArgsString = - arguments.posonlyargs - .zip(defaultArgs) - .map { case (arg, defaultOption) => - print(arg) + defaultOption.map(d => " = " + print(d)).getOrElse("") - } - .mkString("", ", ", ", /") + override def visit(invert: Invert.type): String = + "~" - result += posOnlyArgsString - separatorString = ", " + override def visit(not: Not.type): String = + "not" + + override def visit(uAdd: UAdd.type): String = + "+" + + override def visit(uSub: USub.type): String = + "-" + + override def visit(arg: Arg): String = + arg.arg + arg.annotation.map(a => ": " + print(a)).getOrElse("") + + override def visit(arguments: Arguments): String = + var result = "" + var separatorString = "" + val combinedPosArgSize = arguments.posonlyargs.size + arguments.args.size + val defaultArgs = immutable.List.fill(combinedPosArgSize - arguments.defaults.size)(None) ++ + arguments.defaults.map(Option.apply) + + if arguments.posonlyargs.nonEmpty then + val posOnlyArgsString = + arguments.posonlyargs + .zip(defaultArgs) + .map { case (arg, defaultOption) => + print(arg) + defaultOption.map(d => " = " + print(d)).getOrElse("") + } + .mkString("", ", ", ", /") + + result += posOnlyArgsString + separatorString = ", " + + if arguments.args.nonEmpty then + val defaultsForArgs = defaultArgs.drop(arguments.posonlyargs.size) + val argsString = + arguments.args + .zip(defaultsForArgs) + .map { case (arg, defaultOption) => + print(arg) + defaultOption.map(d => " = " + print(d)).getOrElse("") + } + .mkString(separatorString, ", ", "") + + result += argsString + separatorString = ", " + + arguments.vararg match + case Some(v) => + result += separatorString + result += "*" + print(v) + separatorString = ", " + case None if arguments.kwonlyargs.nonEmpty => + result += separatorString + result += "*" + separatorString = ", " + case None => + + if arguments.kwonlyargs.nonEmpty then + val kwOnlyArgsString = + arguments.kwonlyargs + .zip(arguments.kw_defaults) + .map { case (arg, defaultOption) => + print(arg) + defaultOption.map(d => " = " + print(d)).getOrElse("") + } + .mkString(separatorString, ", ", "") + + result += kwOnlyArgsString + separatorString = ", " + + arguments.kw_arg.foreach { k => + result += separatorString + result += "**" + print(k) + } + + result + end visit + + override def visit(withItem: Withitem): String = + print(withItem.context_expr) + withItem.optional_vars.map(o => " as " + print(o)).getOrElse( + "" + ) + + override def visit(matchCase: MatchCase): String = + "case " + print(matchCase.pattern) + matchCase.guard.map(g => " if " + print(g)).getOrElse( + "" + ) + ":" + + matchCase.body.map(printIndented).mkString(ls, ls, "") + + override def visit(matchValue: MatchValue): String = + print(matchValue.value) + + override def visit(matchSingleton: MatchSingleton): String = + print(matchSingleton.value) + + override def visit(matchSequence: MatchSequence): String = + matchSequence.patterns.map(print).mkString("[", ", ", "]") + + override def visit(matchMapping: MatchMapping): String = + "{" + matchMapping.keys + .zip(matchMapping.patterns) + .map { case (key, pattern) => + print(key) + ": " + print(pattern) + } + .mkString(", ") + + matchMapping.rest + .map { r => + val separatorString = + if matchMapping.keys.nonEmpty then + ", " + else + "" + separatorString + "**" + r + } + .getOrElse("") + "}" + + override def visit(matchClass: MatchClass): String = + val separatorString = + if matchClass.patterns.nonEmpty && matchClass.kwd_patterns.nonEmpty then + ", " + else + "" - if arguments.args.nonEmpty then - val defaultsForArgs = defaultArgs.drop(arguments.posonlyargs.size) - val argsString = - arguments.args - .zip(defaultsForArgs) - .map { case (arg, defaultOption) => - print(arg) + defaultOption.map(d => " = " + print(d)).getOrElse("") - } - .mkString(separatorString, ", ", "") - - result += argsString - separatorString = ", " - - arguments.vararg match - case Some(v) => - result += separatorString - result += "*" + print(v) - separatorString = ", " - case None if arguments.kwonlyargs.nonEmpty => - result += separatorString - result += "*" - separatorString = ", " - case None => + print(matchClass.cls) + + "(" + + matchClass.patterns.map(print).mkString(", ") + + separatorString + + matchClass.kwd_attrs + .zip(matchClass.kwd_patterns) + .map { case (name, pattern) => + name + " = " + print(pattern) + } + .mkString(", ") + + ")" - if arguments.kwonlyargs.nonEmpty then - val kwOnlyArgsString = - arguments.kwonlyargs - .zip(arguments.kw_defaults) - .map { case (arg, defaultOption) => - print(arg) + defaultOption.map(d => " = " + print(d)).getOrElse("") - } - .mkString(separatorString, ", ", "") - - result += kwOnlyArgsString - separatorString = ", " - - arguments.kw_arg.foreach { k => - result += separatorString - result += "**" + print(k) - } - - result - end visit - - override def visit(withItem: Withitem): String = - print(withItem.context_expr) + withItem.optional_vars.map(o => " as " + print(o)).getOrElse( - "" - ) + override def visit(matchStar: MatchStar): String = + "*" + matchStar.name.getOrElse("_") - override def visit(matchCase: MatchCase): String = - "case " + print(matchCase.pattern) + matchCase.guard.map(g => " if " + print(g)).getOrElse( - "" - ) + ":" + - matchCase.body.map(printIndented).mkString(ls, ls, "") + override def visit(matchAs: MatchAs): String = + matchAs.pattern match + case Some(pattern) => + print(pattern) + matchAs.name.map(name => " as " + name).getOrElse("") + case None => + matchAs.name.getOrElse("_") - override def visit(matchValue: MatchValue): String = - print(matchValue.value) + override def visit(matchOr: MatchOr): String = + matchOr.patterns.map(print).mkString(" | ") - override def visit(matchSingleton: MatchSingleton): String = - print(matchSingleton.value) + override def visit(comprehension: Comprehension): String = + val prefix = + if comprehension.is_async then + " async for " + else + " for " - override def visit(matchSequence: MatchSequence): String = - matchSequence.patterns.map(print).mkString("[", ", ", "]") + prefix + print(comprehension.target) + " in " + print(comprehension.iter) + + comprehension.ifs.map(i => " if " + print(i)).mkString("") - override def visit(matchMapping: MatchMapping): String = - "{" + matchMapping.keys - .zip(matchMapping.patterns) - .map { case (key, pattern) => - print(key) + ": " + print(pattern) - } - .mkString(", ") + - matchMapping.rest - .map { r => - val separatorString = - if matchMapping.keys.nonEmpty then - ", " - else - "" - separatorString + "**" + r - } - .getOrElse("") + "}" - - override def visit(matchClass: MatchClass): String = - val separatorString = - if matchClass.patterns.nonEmpty && matchClass.kwd_patterns.nonEmpty then - ", " - else - "" - - print(matchClass.cls) + - "(" + - matchClass.patterns.map(print).mkString(", ") + - separatorString + - matchClass.kwd_attrs - .zip(matchClass.kwd_patterns) - .map { case (name, pattern) => - name + " = " + print(pattern) - } - .mkString(", ") + - ")" - - override def visit(matchStar: MatchStar): String = - "*" + matchStar.name.getOrElse("_") - - override def visit(matchAs: MatchAs): String = - matchAs.pattern match - case Some(pattern) => - print(pattern) + matchAs.name.map(name => " as " + name).getOrElse("") - case None => - matchAs.name.getOrElse("_") - - override def visit(matchOr: MatchOr): String = - matchOr.patterns.map(print).mkString(" | ") - - override def visit(comprehension: Comprehension): String = - val prefix = - if comprehension.is_async then - " async for " - else - " for " - - prefix + print(comprehension.target) + " in " + print(comprehension.iter) + - comprehension.ifs.map(i => " if " + print(i)).mkString("") - - override def visit(typeIgnore: TypeIgnore): String = - typeIgnore.tag + override def visit(typeIgnore: TypeIgnore): String = + typeIgnore.tag end AstPrinter diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/AstVisitor.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/AstVisitor.scala index 0b89902a..18d03da4 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/AstVisitor.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/AstVisitor.scala @@ -123,149 +123,149 @@ import io.appthreat.pythonparser.ast.{ import io.appthreat.pythonparser.ast.* trait AstVisitor[T]: - def visit(ast: iast): T + def visit(ast: iast): T - def visit(mod: imod): T - def visit(module: Module): T + def visit(mod: imod): T + def visit(module: Module): T - def visit(stmt: istmt): T - def visit(functionDef: FunctionDef): T - def visit(functionDef: AsyncFunctionDef): T - def visit(classDef: ClassDef): T - def visit(ret: Return): T - def visit(delete: Delete): T - def visit(assign: Assign): T - def visit(annAssign: AnnAssign): T - def visit(augAssign: AugAssign): T - def visit(forStmt: For): T - def visit(forStmt: AsyncFor): T - def visit(whileStmt: While): T - def visit(ifStmt: If): T - def visit(withStmt: With): T - def visit(withStmt: AsyncWith): T - def visit(matchStmt: Match): T - def visit(raise: Raise): T - def visit(tryStmt: Try): T - def visit(assert: Assert): T - def visit(importStmt: Import): T - def visit(importFrom: ImportFrom): T - def visit(global: Global): T - def visit(nonlocal: Nonlocal): T - def visit(expr: Expr): T - def visit(pass: Pass): T - def visit(break: Break): T - def visit(continue: Continue): T - def visit(raise: RaiseP2): T - def visit(errorStatement: ErrorStatement): T + def visit(stmt: istmt): T + def visit(functionDef: FunctionDef): T + def visit(functionDef: AsyncFunctionDef): T + def visit(classDef: ClassDef): T + def visit(ret: Return): T + def visit(delete: Delete): T + def visit(assign: Assign): T + def visit(annAssign: AnnAssign): T + def visit(augAssign: AugAssign): T + def visit(forStmt: For): T + def visit(forStmt: AsyncFor): T + def visit(whileStmt: While): T + def visit(ifStmt: If): T + def visit(withStmt: With): T + def visit(withStmt: AsyncWith): T + def visit(matchStmt: Match): T + def visit(raise: Raise): T + def visit(tryStmt: Try): T + def visit(assert: Assert): T + def visit(importStmt: Import): T + def visit(importFrom: ImportFrom): T + def visit(global: Global): T + def visit(nonlocal: Nonlocal): T + def visit(expr: Expr): T + def visit(pass: Pass): T + def visit(break: Break): T + def visit(continue: Continue): T + def visit(raise: RaiseP2): T + def visit(errorStatement: ErrorStatement): T - def visit(expr: iexpr): T - def visit(boolOp: BoolOp): T - def visit(namedExpr: NamedExpr): T - def visit(binOp: BinOp): T - def visit(unaryOp: UnaryOp): T - def visit(lambda: Lambda): T - def visit(ifExp: IfExp): T - def visit(dict: Dict): T - def visit(set: ast.Set): T - def visit(listComp: ListComp): T - def visit(setComp: SetComp): T - def visit(dictComp: DictComp): T - def visit(generatorExp: GeneratorExp): T - def visit(await: Await): T - def visit(yieldExpr: Yield): T - def visit(yieldFrom: YieldFrom): T - def visit(compare: Compare): T - def visit(call: Call): T - def visit(formattedValue: FormattedValue): T - def visit(joinedString: JoinedString): T - def visit(constant: Constant): T - def visit(attribute: Attribute): T - def visit(subscript: Subscript): T - def visit(starred: Starred): T - def visit(name: Name): T - def visit(list: ast.List): T - def visit(tuple: Tuple): T - def visit(slice: Slice): T - def visit(stringExpList: StringExpList): T + def visit(expr: iexpr): T + def visit(boolOp: BoolOp): T + def visit(namedExpr: NamedExpr): T + def visit(binOp: BinOp): T + def visit(unaryOp: UnaryOp): T + def visit(lambda: Lambda): T + def visit(ifExp: IfExp): T + def visit(dict: Dict): T + def visit(set: ast.Set): T + def visit(listComp: ListComp): T + def visit(setComp: SetComp): T + def visit(dictComp: DictComp): T + def visit(generatorExp: GeneratorExp): T + def visit(await: Await): T + def visit(yieldExpr: Yield): T + def visit(yieldFrom: YieldFrom): T + def visit(compare: Compare): T + def visit(call: Call): T + def visit(formattedValue: FormattedValue): T + def visit(joinedString: JoinedString): T + def visit(constant: Constant): T + def visit(attribute: Attribute): T + def visit(subscript: Subscript): T + def visit(starred: Starred): T + def visit(name: Name): T + def visit(list: ast.List): T + def visit(tuple: Tuple): T + def visit(slice: Slice): T + def visit(stringExpList: StringExpList): T - def visit(boolop: iboolop): T - def visit(and: And.type): T - def visit(or: Or.type): T + def visit(boolop: iboolop): T + def visit(and: And.type): T + def visit(or: Or.type): T - def visit(operator: ioperator): T - def visit(add: Add.type): T - def visit(sub: Sub.type): T - def visit(mult: Mult.type): T - def visit(matMult: MatMult.type): T - def visit(div: Div.type): T - def visit(mod: Mod.type): T - def visit(pow: Pow.type): T - def visit(lShift: LShift.type): T - def visit(rShift: RShift.type): T - def visit(bitOr: BitOr.type): T - def visit(bitXor: BitXor.type): T - def visit(bitAnd: BitAnd.type): T - def visit(floorDiv: FloorDiv.type): T + def visit(operator: ioperator): T + def visit(add: Add.type): T + def visit(sub: Sub.type): T + def visit(mult: Mult.type): T + def visit(matMult: MatMult.type): T + def visit(div: Div.type): T + def visit(mod: Mod.type): T + def visit(pow: Pow.type): T + def visit(lShift: LShift.type): T + def visit(rShift: RShift.type): T + def visit(bitOr: BitOr.type): T + def visit(bitXor: BitXor.type): T + def visit(bitAnd: BitAnd.type): T + def visit(floorDiv: FloorDiv.type): T - def visit(unaryop: iunaryop): T - def visit(invert: Invert.type): T - def visit(not: Not.type): T - def visit(uAdd: UAdd.type): T - def visit(uSub: USub.type): T + def visit(unaryop: iunaryop): T + def visit(invert: Invert.type): T + def visit(not: Not.type): T + def visit(uAdd: UAdd.type): T + def visit(uSub: USub.type): T - def visit(compop: icompop): T - def visit(eq: Eq.type): T - def visit(notEq: NotEq.type): T - def visit(lt: Lt.type): T - def visit(ltE: LtE.type): T - def visit(gt: Gt.type): T - def visit(gtE: GtE.type): T - def visit(is: Is.type): T - def visit(isNot: IsNot.type): T - def visit(in: In.type): T - def visit(notIn: NotIn.type): T + def visit(compop: icompop): T + def visit(eq: Eq.type): T + def visit(notEq: NotEq.type): T + def visit(lt: Lt.type): T + def visit(ltE: LtE.type): T + def visit(gt: Gt.type): T + def visit(gtE: GtE.type): T + def visit(is: Is.type): T + def visit(isNot: IsNot.type): T + def visit(in: In.type): T + def visit(notIn: NotIn.type): T - def visit(comprehension: Comprehension): T + def visit(comprehension: Comprehension): T - def visit(exceptHandler: ExceptHandler): T + def visit(exceptHandler: ExceptHandler): T - def visit(arguments: Arguments): T + def visit(arguments: Arguments): T - def visit(arg: Arg): T + def visit(arg: Arg): T - def visit(constant: iconstant): T - def visit(stringConstant: StringConstant): T - def visit(joinedStringConstant: JoinedStringConstant): T - def visit(boolConstant: BoolConstant): T - def visit(intConstant: IntConstant): T - def visit(intConstant: FloatConstant): T - def visit(imaginaryConstant: ImaginaryConstant): T - def visit(noneConstant: NoneConstant.type): T - def visit(ellipsisConstant: EllipsisConstant.type): T + def visit(constant: iconstant): T + def visit(stringConstant: StringConstant): T + def visit(joinedStringConstant: JoinedStringConstant): T + def visit(boolConstant: BoolConstant): T + def visit(intConstant: IntConstant): T + def visit(intConstant: FloatConstant): T + def visit(imaginaryConstant: ImaginaryConstant): T + def visit(noneConstant: NoneConstant.type): T + def visit(ellipsisConstant: EllipsisConstant.type): T - def visit(keyword: Keyword): T + def visit(keyword: Keyword): T - def visit(alias: Alias): T + def visit(alias: Alias): T - def visit(withItem: Withitem): T + def visit(withItem: Withitem): T - def visit(matchCase: MatchCase): T + def visit(matchCase: MatchCase): T - def visit(matchValue: MatchValue): T + def visit(matchValue: MatchValue): T - def visit(matchSingleton: MatchSingleton): T + def visit(matchSingleton: MatchSingleton): T - def visit(matchSequence: MatchSequence): T + def visit(matchSequence: MatchSequence): T - def visit(matchMapping: MatchMapping): T + def visit(matchMapping: MatchMapping): T - def visit(matchClass: MatchClass): T + def visit(matchClass: MatchClass): T - def visit(matchStar: MatchStar): T + def visit(matchStar: MatchStar): T - def visit(matchAs: MatchAs): T + def visit(matchAs: MatchAs): T - def visit(matchOr: MatchOr): T + def visit(matchOr: MatchOr): T - def visit(typeIgnore: TypeIgnore): T + def visit(typeIgnore: TypeIgnore): T end AstVisitor diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/CharStreamImpl.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/CharStreamImpl.scala index b15c84e8..2542330e 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/CharStreamImpl.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/CharStreamImpl.scala @@ -5,203 +5,202 @@ import CharStreamImpl.{defaultInputBufferSize, defaultMinimumReadSize} import java.io.{IOException, InputStream, InputStreamReader} object CharStreamImpl: - private val defaultInputBufferSize = 4096 - private val defaultMinimumReadSize = 2048 + private val defaultInputBufferSize = 4096 + private val defaultMinimumReadSize = 2048 class CharStreamImpl(inputStream: InputStream, inputBufferSize: Int, minimumReadSize: Int) extends CharStream: - private val inputReader = new InputStreamReader(inputStream) - - private var inputBuffer = new Array[Char](inputBufferSize) - private var posToLine = new Array[Int](inputBufferSize) - private var posToColumn = new Array[Int](inputBufferSize) - private var readPos = 1 - private var writePos = 1 - private var tokenBeginPos = 1 - private var inputBufferOffset = 0 // From start of inputStream - private var tabSize = 1 - - // The first slot of inputBuffer, posToLine, posToColumn represents the last value - // from the previous chunk red from inputStream. For the very first chunk we - // initialise the values by hand so that requires no special cases in the implementation. - inputBuffer(0) = 'a' // We could have picked any char that is not '\n' or '\r'. - posToLine(0) = 1 - posToColumn(0) = 0 - - def this(inputStream: InputStream) = - this(inputStream, defaultInputBufferSize, defaultMinimumReadSize) - - def getBeginPos: Int = inputBufferOffset + tokenBeginPos - 1 - - private def fillBuffer(): Unit = - if readPos == writePos then - // No more data to read - if writePos == inputBuffer.length then - // No more space in inputBuffer - - val keepStartPos = tokenBeginPos - 1 - val charsToKeep = writePos - keepStartPos - if inputBuffer.length - charsToKeep < minimumReadSize then - // Resize buffer and move content to the front. - val newBufferLen = charsToKeep + minimumReadSize - - val newInputBuffer = new Array[Char](newBufferLen) - Array.copy(inputBuffer, keepStartPos, newInputBuffer, 0, charsToKeep) - inputBuffer = newInputBuffer - - val newPosToLine = new Array[Int](newBufferLen) - Array.copy(posToLine, keepStartPos, newPosToLine, 0, charsToKeep) - posToLine = newPosToLine - - val newPosToColumn = new Array[Int](newBufferLen) - Array.copy(posToColumn, keepStartPos, newPosToColumn, 0, charsToKeep) - posToColumn = newPosToColumn - else - // Enough space left to just move content to the front. - Array.copy(inputBuffer, keepStartPos, inputBuffer, 0, charsToKeep) - Array.copy(posToLine, keepStartPos, posToLine, 0, charsToKeep) - Array.copy(posToColumn, keepStartPos, posToColumn, 0, charsToKeep) - end if - writePos = charsToKeep - readPos = readPos - keepStartPos - inputBufferOffset += keepStartPos - tokenBeginPos = 1 - end if - - val charsRed = inputReader.read(inputBuffer, writePos, inputBuffer.length - writePos) - if charsRed != -1 then - writePos += charsRed - else - throw new IOException() - - private def lastRedPos: Int = - readPos - 1 - - private def updateLineAndColumn(pos: Int, char: Char, prevChar: Char): Unit = - - val newLine = - prevChar match - case '\n' => - true - case '\r' => - if char == '\n' then - false - else - true - case _ => - false - - if newLine then - posToLine(pos) = posToLine(pos - 1) + 1 - posToColumn(pos) = 1 + private val inputReader = new InputStreamReader(inputStream) + + private var inputBuffer = new Array[Char](inputBufferSize) + private var posToLine = new Array[Int](inputBufferSize) + private var posToColumn = new Array[Int](inputBufferSize) + private var readPos = 1 + private var writePos = 1 + private var tokenBeginPos = 1 + private var inputBufferOffset = 0 // From start of inputStream + private var tabSize = 1 + + // The first slot of inputBuffer, posToLine, posToColumn represents the last value + // from the previous chunk red from inputStream. For the very first chunk we + // initialise the values by hand so that requires no special cases in the implementation. + inputBuffer(0) = 'a' // We could have picked any char that is not '\n' or '\r'. + posToLine(0) = 1 + posToColumn(0) = 0 + + def this(inputStream: InputStream) = + this(inputStream, defaultInputBufferSize, defaultMinimumReadSize) + + def getBeginPos: Int = inputBufferOffset + tokenBeginPos - 1 + + private def fillBuffer(): Unit = + if readPos == writePos then + // No more data to read + if writePos == inputBuffer.length then + // No more space in inputBuffer + + val keepStartPos = tokenBeginPos - 1 + val charsToKeep = writePos - keepStartPos + if inputBuffer.length - charsToKeep < minimumReadSize then + // Resize buffer and move content to the front. + val newBufferLen = charsToKeep + minimumReadSize + + val newInputBuffer = new Array[Char](newBufferLen) + Array.copy(inputBuffer, keepStartPos, newInputBuffer, 0, charsToKeep) + inputBuffer = newInputBuffer + + val newPosToLine = new Array[Int](newBufferLen) + Array.copy(posToLine, keepStartPos, newPosToLine, 0, charsToKeep) + posToLine = newPosToLine + + val newPosToColumn = new Array[Int](newBufferLen) + Array.copy(posToColumn, keepStartPos, newPosToColumn, 0, charsToKeep) + posToColumn = newPosToColumn + else + // Enough space left to just move content to the front. + Array.copy(inputBuffer, keepStartPos, inputBuffer, 0, charsToKeep) + Array.copy(posToLine, keepStartPos, posToLine, 0, charsToKeep) + Array.copy(posToColumn, keepStartPos, posToColumn, 0, charsToKeep) + end if + writePos = charsToKeep + readPos = readPos - keepStartPos + inputBufferOffset += keepStartPos + tokenBeginPos = 1 + end if + + val charsRed = inputReader.read(inputBuffer, writePos, inputBuffer.length - writePos) + if charsRed != -1 then + writePos += charsRed else - posToLine(pos) = posToLine(pos - 1) - posToColumn(pos) = posToColumn(pos - 1) + 1 - - if char == '\t' then - posToColumn(pos) += -1 + (tabSize - (posToColumn(pos) % tabSize)) - end updateLineAndColumn - - /** Returns the next character from the selected input. The method of selecting the input is the - * responsibility of the class implementing this interface. Can throw any java.io.IOException. - */ - override def readChar(): Char = - fillBuffer() - val char = inputBuffer(readPos) - val prevChar = inputBuffer(lastRedPos) - updateLineAndColumn(readPos, char, prevChar) - readPos += 1 - char - - /** Returns the column position of the character last read. - * - * @deprecated - * @see - * #getEndColumn - */ - override def getColumn: Int = ??? - - /** Returns the line number of the character last read. - * - * @deprecated - * @see - * #getEndLine - */ - override def getLine: Int = ??? - - /** Returns the column number of the last character for current token (being matched after the - * last call to BeginTOken). - */ - override def getEndColumn: Int = - posToColumn(lastRedPos) - - /** Returns the line number of the last character for current token (being matched after the - * last call to BeginTOken). - */ - override def getEndLine: Int = - posToLine(lastRedPos) - - /** Returns the column number of the first character for current token (being matched after the - * last call to BeginTOken). - */ - override def getBeginColumn: Int = - posToColumn(tokenBeginPos) - - /** Returns the line number of the first character for current token (being matched after the - * last call to BeginTOken). - */ - override def getBeginLine: Int = - posToLine(tokenBeginPos) - - /** Backs up the input stream by amount steps. Lexer calls this method if it had already read - * some characters, but could not use them to match a (longer) token. So, they will be used - * again as the prefix of the next token and it is the implementation's responsibility to do - * this right. - */ - override def backup(amount: Int): Unit = - readPos -= amount - - /** Returns the next character that marks the beginning of the next token. All characters must - * remain in the buffer between two successive calls to this method to implement backup - * correctly. - */ - override def BeginToken(): Char = - tokenBeginPos = readPos - readChar() - - /** Returns a string made up of characters from the marked token beginning to the current buffer - * position. Implementations have the choice of returning anything that they want to. For - * example, for efficiency, one might decide to just return null, which is a valid - * implementation. - */ - override def GetImage(): String = - new String(inputBuffer, tokenBeginPos, readPos - tokenBeginPos) - - /** Returns an array of characters that make up the suffix of length 'len' for the currently - * matched token. This is used to build up the matched string for use in actions in the case of - * MORE. A simple and inefficient implementation of this is as follows : - * - * { String t = GetImage(); return t.substring(t.length() - len, t.length()).toCharArray(); } - */ - override def GetSuffix(len: Int): Array[Char] = - val suffix = new Array[Char](len) - Array.copy(inputBuffer, readPos - len, suffix, 0, len) - suffix - - /** The lexer calls this function to indicate that it is done with the stream and hence - * implementations can free any resources held by this class. Again, the body of this function - * can be just empty and it will not affect the lexer's operation. - */ - override def Done(): Unit = - inputReader.close() - - override def setTabSize(i: Int): Unit = - tabSize = i - - override def getTabSize: Int = - tabSize - - override def getTrackLineColumn: Boolean = ??? - - override def setTrackLineColumn(trackLineColumn: Boolean): Unit = ??? + throw new IOException() + + private def lastRedPos: Int = + readPos - 1 + + private def updateLineAndColumn(pos: Int, char: Char, prevChar: Char): Unit = + + val newLine = + prevChar match + case '\n' => + true + case '\r' => + if char == '\n' then + false + else + true + case _ => + false + + if newLine then + posToLine(pos) = posToLine(pos - 1) + 1 + posToColumn(pos) = 1 + else + posToLine(pos) = posToLine(pos - 1) + posToColumn(pos) = posToColumn(pos - 1) + 1 + + if char == '\t' then + posToColumn(pos) += -1 + (tabSize - (posToColumn(pos) % tabSize)) + end updateLineAndColumn + + /** Returns the next character from the selected input. The method of selecting the input is the + * responsibility of the class implementing this interface. Can throw any java.io.IOException. + */ + override def readChar(): Char = + fillBuffer() + val char = inputBuffer(readPos) + val prevChar = inputBuffer(lastRedPos) + updateLineAndColumn(readPos, char, prevChar) + readPos += 1 + char + + /** Returns the column position of the character last read. + * + * @deprecated + * @see + * #getEndColumn + */ + override def getColumn: Int = ??? + + /** Returns the line number of the character last read. + * + * @deprecated + * @see + * #getEndLine + */ + override def getLine: Int = ??? + + /** Returns the column number of the last character for current token (being matched after the + * last call to BeginTOken). + */ + override def getEndColumn: Int = + posToColumn(lastRedPos) + + /** Returns the line number of the last character for current token (being matched after the last + * call to BeginTOken). + */ + override def getEndLine: Int = + posToLine(lastRedPos) + + /** Returns the column number of the first character for current token (being matched after the + * last call to BeginTOken). + */ + override def getBeginColumn: Int = + posToColumn(tokenBeginPos) + + /** Returns the line number of the first character for current token (being matched after the last + * call to BeginTOken). + */ + override def getBeginLine: Int = + posToLine(tokenBeginPos) + + /** Backs up the input stream by amount steps. Lexer calls this method if it had already read some + * characters, but could not use them to match a (longer) token. So, they will be used again as + * the prefix of the next token and it is the implementation's responsibility to do this right. + */ + override def backup(amount: Int): Unit = + readPos -= amount + + /** Returns the next character that marks the beginning of the next token. All characters must + * remain in the buffer between two successive calls to this method to implement backup + * correctly. + */ + override def BeginToken(): Char = + tokenBeginPos = readPos + readChar() + + /** Returns a string made up of characters from the marked token beginning to the current buffer + * position. Implementations have the choice of returning anything that they want to. For + * example, for efficiency, one might decide to just return null, which is a valid + * implementation. + */ + override def GetImage(): String = + new String(inputBuffer, tokenBeginPos, readPos - tokenBeginPos) + + /** Returns an array of characters that make up the suffix of length 'len' for the currently + * matched token. This is used to build up the matched string for use in actions in the case of + * MORE. A simple and inefficient implementation of this is as follows : + * + * { String t = GetImage(); return t.substring(t.length() - len, t.length()).toCharArray(); } + */ + override def GetSuffix(len: Int): Array[Char] = + val suffix = new Array[Char](len) + Array.copy(inputBuffer, readPos - len, suffix, 0, len) + suffix + + /** The lexer calls this function to indicate that it is done with the stream and hence + * implementations can free any resources held by this class. Again, the body of this function + * can be just empty and it will not affect the lexer's operation. + */ + override def Done(): Unit = + inputReader.close() + + override def setTabSize(i: Int): Unit = + tabSize = i + + override def getTabSize: Int = + tabSize + + override def getTrackLineColumn: Boolean = ??? + + override def setTrackLineColumn(trackLineColumn: Boolean): Unit = ??? end CharStreamImpl diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/PyParser.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/PyParser.scala index 276ba03a..8a016e11 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/PyParser.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/PyParser.scala @@ -9,18 +9,18 @@ import java.nio.charset.StandardCharsets import scala.jdk.CollectionConverters.* class PyParser: - private var pythonParser: PythonParser = scala.compiletime.uninitialized + private var pythonParser: PythonParser = scala.compiletime.uninitialized - def parse(code: String): iast = - parse(new ByteArrayInputStream(code.getBytes(StandardCharsets.UTF_8))) + def parse(code: String): iast = + parse(new ByteArrayInputStream(code.getBytes(StandardCharsets.UTF_8))) - def parse(inputStream: InputStream): iast = - pythonParser = new PythonParser(new CharStreamImpl(inputStream)) - // We start in INDENT_CHECK lexer state because we want to detect indentations - // also for the first line. - pythonParser.token_source.SwitchTo(PythonParserConstants.INDENT_CHECK) - val module = pythonParser.module() - module + def parse(inputStream: InputStream): iast = + pythonParser = new PythonParser(new CharStreamImpl(inputStream)) + // We start in INDENT_CHECK lexer state because we want to detect indentations + // also for the first line. + pythonParser.token_source.SwitchTo(PythonParserConstants.INDENT_CHECK) + val module = pythonParser.module() + module - def errors: Iterable[ErrorStatement] = - pythonParser.getErrors.asScala + def errors: Iterable[ErrorStatement] = + pythonParser.getErrors.asScala diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/ast/Ast.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/ast/Ast.scala index 122d84af..fb2d3690 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/ast/Ast.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/ast/Ast.scala @@ -37,7 +37,7 @@ import scala.jdk.CollectionConverters.* // AST root trait /////////////////////////////////////////////////////////////////////////////////////////////////// trait iast: - def accept[T](visitor: AstVisitor[T]): T + def accept[T](visitor: AstVisitor[T]): T /////////////////////////////////////////////////////////////////////////////////////////////////// // AST module classes @@ -45,10 +45,10 @@ trait iast: trait imod extends iast case class Module(stmts: CollType[istmt], type_ignores: CollType[TypeIgnore]) extends imod: - def this(stmts: util.ArrayList[istmt], type_ignores: util.ArrayList[TypeIgnore]) = - this(stmts.asScala, type_ignores.asScala) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this(stmts: util.ArrayList[istmt], type_ignores: util.ArrayList[TypeIgnore]) = + this(stmts.asScala, type_ignores.asScala) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) /////////////////////////////////////////////////////////////////////////////////////////////////// // AST statement classes @@ -65,26 +65,26 @@ case class FunctionDef( type_comment: Option[String], attributeProvider: AttributeProvider ) extends istmt: - def this( - name: String, - args: Arguments, - body: util.ArrayList[istmt], - decorator_list: util.ArrayList[iexpr], - returns: iexpr, - type_comment: String, - attributeProvider: AttributeProvider - ) = - this( - name, - args, - body.asScala, - decorator_list.asScala, - Option(returns), - Option(type_comment), - attributeProvider - ) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this( + name: String, + args: Arguments, + body: util.ArrayList[istmt], + decorator_list: util.ArrayList[iexpr], + returns: iexpr, + type_comment: String, + attributeProvider: AttributeProvider + ) = + this( + name, + args, + body.asScala, + decorator_list.asScala, + Option(returns), + Option(type_comment), + attributeProvider + ) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) end FunctionDef case class AsyncFunctionDef( @@ -96,26 +96,26 @@ case class AsyncFunctionDef( type_comment: Option[String], attributeProvider: AttributeProvider ) extends istmt: - def this( - name: String, - args: Arguments, - body: util.ArrayList[istmt], - decorator_list: util.ArrayList[iexpr], - returns: iexpr, - type_comment: String, - attributeProvider: AttributeProvider - ) = - this( - name, - args, - body.asScala, - decorator_list.asScala, - Option(returns), - Option(type_comment), - attributeProvider - ) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this( + name: String, + args: Arguments, + body: util.ArrayList[istmt], + decorator_list: util.ArrayList[iexpr], + returns: iexpr, + type_comment: String, + attributeProvider: AttributeProvider + ) = + this( + name, + args, + body.asScala, + decorator_list.asScala, + Option(returns), + Option(type_comment), + attributeProvider + ) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) end AsyncFunctionDef case class ClassDef( @@ -126,37 +126,37 @@ case class ClassDef( decorator_list: CollType[iexpr], attributeProvider: AttributeProvider ) extends istmt: - def this( - name: String, - bases: util.ArrayList[iexpr], - keywords: util.ArrayList[Keyword], - body: util.ArrayList[istmt], - decorator_list: util.ArrayList[iexpr], - attributeProvider: AttributeProvider - ) = - this( - name, - bases.asScala, - keywords.asScala, - body.asScala, - decorator_list.asScala, - attributeProvider - ) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this( + name: String, + bases: util.ArrayList[iexpr], + keywords: util.ArrayList[Keyword], + body: util.ArrayList[istmt], + decorator_list: util.ArrayList[iexpr], + attributeProvider: AttributeProvider + ) = + this( + name, + bases.asScala, + keywords.asScala, + body.asScala, + decorator_list.asScala, + attributeProvider + ) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) end ClassDef case class Return(value: Option[iexpr], attributeProvider: AttributeProvider) extends istmt: - def this(value: iexpr, attributeProvider: AttributeProvider) = - this(Option(value), attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this(value: iexpr, attributeProvider: AttributeProvider) = + this(Option(value), attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class Delete(targets: CollType[iexpr], attributeProvider: AttributeProvider) extends istmt: - def this(targets: util.ArrayList[iexpr], attributeProvider: AttributeProvider) = - this(targets.asScala, attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this(targets: util.ArrayList[iexpr], attributeProvider: AttributeProvider) = + this(targets.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class Assign( targets: CollType[iexpr], @@ -164,10 +164,10 @@ case class Assign( typeComment: Option[String], attributeProvider: AttributeProvider ) extends istmt: - def this(targets: util.ArrayList[iexpr], value: iexpr, attributeProvider: AttributeProvider) = - this(targets.asScala, value, None, attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this(targets: util.ArrayList[iexpr], value: iexpr, attributeProvider: AttributeProvider) = + this(targets.asScala, value, None, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class AugAssign( target: iexpr, @@ -175,8 +175,8 @@ case class AugAssign( value: iexpr, attributeProvider: AttributeProvider ) extends istmt: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class AnnAssign( target: iexpr, @@ -185,16 +185,16 @@ case class AnnAssign( simple: Boolean, attributeProvider: AttributeProvider ) extends istmt: - def this( - target: iexpr, - annotation: iexpr, - value: iexpr, - simple: Boolean, - attributeProvider: AttributeProvider - ) = - this(target, annotation, Option(value), simple, attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this( + target: iexpr, + annotation: iexpr, + value: iexpr, + simple: Boolean, + attributeProvider: AttributeProvider + ) = + this(target, annotation, Option(value), simple, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class For( target: iexpr, @@ -204,17 +204,17 @@ case class For( type_comment: Option[String], attributeProvider: AttributeProvider ) extends istmt: - def this( - target: iexpr, - iter: iexpr, - body: util.ArrayList[istmt], - orelse: util.ArrayList[istmt], - type_comment: String, - attributeProvider: AttributeProvider - ) = - this(target, iter, body.asScala, orelse.asScala, Option(type_comment), attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this( + target: iexpr, + iter: iexpr, + body: util.ArrayList[istmt], + orelse: util.ArrayList[istmt], + type_comment: String, + attributeProvider: AttributeProvider + ) = + this(target, iter, body.asScala, orelse.asScala, Option(type_comment), attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) end For case class AsyncFor( @@ -225,17 +225,17 @@ case class AsyncFor( type_comment: Option[String], attributeProvider: AttributeProvider ) extends istmt: - def this( - target: iexpr, - iter: iexpr, - body: util.ArrayList[istmt], - orelse: util.ArrayList[istmt], - type_comment: String, - attributeProvider: AttributeProvider - ) = - this(target, iter, body.asScala, orelse.asScala, Option(type_comment), attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this( + target: iexpr, + iter: iexpr, + body: util.ArrayList[istmt], + orelse: util.ArrayList[istmt], + type_comment: String, + attributeProvider: AttributeProvider + ) = + this(target, iter, body.asScala, orelse.asScala, Option(type_comment), attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) end AsyncFor case class While( @@ -244,15 +244,15 @@ case class While( orelse: CollType[istmt], attributeProvider: AttributeProvider ) extends istmt: - def this( - test: iexpr, - body: util.ArrayList[istmt], - orelse: util.ArrayList[istmt], - attributeProvider: AttributeProvider - ) = - this(test, body.asScala, orelse.asScala, attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this( + test: iexpr, + body: util.ArrayList[istmt], + orelse: util.ArrayList[istmt], + attributeProvider: AttributeProvider + ) = + this(test, body.asScala, orelse.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class If( test: iexpr, @@ -260,15 +260,15 @@ case class If( orelse: CollType[istmt], attributeProvider: AttributeProvider ) extends istmt: - def this( - test: iexpr, - body: util.ArrayList[istmt], - orelse: util.ArrayList[istmt], - attributeProvider: AttributeProvider - ) = - this(test, body.asScala, orelse.asScala, attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this( + test: iexpr, + body: util.ArrayList[istmt], + orelse: util.ArrayList[istmt], + attributeProvider: AttributeProvider + ) = + this(test, body.asScala, orelse.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class With( items: CollType[Withitem], @@ -276,15 +276,15 @@ case class With( type_comment: Option[String], attributeProvider: AttributeProvider ) extends istmt: - def this( - items: util.ArrayList[Withitem], - body: util.ArrayList[istmt], - type_comment: String, - attributeProvider: AttributeProvider - ) = - this(items.asScala, body.asScala, Option(type_comment), attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this( + items: util.ArrayList[Withitem], + body: util.ArrayList[istmt], + type_comment: String, + attributeProvider: AttributeProvider + ) = + this(items.asScala, body.asScala, Option(type_comment), attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class AsyncWith( items: CollType[Withitem], @@ -292,29 +292,29 @@ case class AsyncWith( type_comment: Option[String], attributeProvider: AttributeProvider ) extends istmt: - def this( - items: util.ArrayList[Withitem], - body: util.ArrayList[istmt], - type_comment: String, - attributeProvider: AttributeProvider - ) = - this(items.asScala, body.asScala, Option(type_comment), attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this( + items: util.ArrayList[Withitem], + body: util.ArrayList[istmt], + type_comment: String, + attributeProvider: AttributeProvider + ) = + this(items.asScala, body.asScala, Option(type_comment), attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class Match(subject: iexpr, cases: CollType[MatchCase], attributeProvider: AttributeProvider) extends istmt: - def this(subject: iexpr, cases: util.List[MatchCase], attributeProvider: AttributeProvider) = - this(subject, cases.asScala, attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this(subject: iexpr, cases: util.List[MatchCase], attributeProvider: AttributeProvider) = + this(subject, cases.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class Raise(exc: Option[iexpr], cause: Option[iexpr], attributeProvider: AttributeProvider) extends istmt: - def this(exc: iexpr, cause: iexpr, attributeProvider: AttributeProvider) = - this(Option(exc), Option(cause), attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this(exc: iexpr, cause: iexpr, attributeProvider: AttributeProvider) = + this(Option(exc), Option(cause), attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class Try( body: CollType[istmt], @@ -323,29 +323,29 @@ case class Try( finalbody: CollType[istmt], attributeProvider: AttributeProvider ) extends istmt: - def this( - body: util.ArrayList[istmt], - handlers: util.ArrayList[ExceptHandler], - orelse: util.ArrayList[istmt], - finalbody: util.ArrayList[istmt], - attributeProvider: AttributeProvider - ) = - this(body.asScala, handlers.asScala, orelse.asScala, finalbody.asScala, attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this( + body: util.ArrayList[istmt], + handlers: util.ArrayList[ExceptHandler], + orelse: util.ArrayList[istmt], + finalbody: util.ArrayList[istmt], + attributeProvider: AttributeProvider + ) = + this(body.asScala, handlers.asScala, orelse.asScala, finalbody.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class Assert(test: iexpr, msg: Option[iexpr], attributeProvider: AttributeProvider) extends istmt: - def this(test: iexpr, msg: iexpr, attributeProvider: AttributeProvider) = - this(test, Option(msg), attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this(test: iexpr, msg: iexpr, attributeProvider: AttributeProvider) = + this(test, Option(msg), attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class Import(names: CollType[Alias], attributeProvider: AttributeProvider) extends istmt: - def this(names: util.ArrayList[Alias], attributeProvider: AttributeProvider) = - this(names.asScala, attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this(names: util.ArrayList[Alias], attributeProvider: AttributeProvider) = + this(names.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class ImportFrom( module: Option[String], @@ -353,45 +353,45 @@ case class ImportFrom( level: Int, attributeProvider: AttributeProvider ) extends istmt: - def this( - module: String, - names: util.ArrayList[Alias], - level: Int, - attributeProvider: AttributeProvider - ) = - this(Option(module), names.asScala, level, attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this( + module: String, + names: util.ArrayList[Alias], + level: Int, + attributeProvider: AttributeProvider + ) = + this(Option(module), names.asScala, level, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class Global(names: CollType[String], attributeProvider: AttributeProvider) extends istmt: - def this(names: util.ArrayList[String], attributeProvider: AttributeProvider) = - this(names.asScala, attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this(names: util.ArrayList[String], attributeProvider: AttributeProvider) = + this(names.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class Nonlocal(names: CollType[String], attributeProvider: AttributeProvider) extends istmt: - def this(names: util.ArrayList[String], attributeProvider: AttributeProvider) = - this(names.asScala, attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this(names: util.ArrayList[String], attributeProvider: AttributeProvider) = + this(names.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class Expr(value: iexpr, attributeProvider: AttributeProvider) extends istmt: - def this(value: iexpr) = - this(value, value.attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this(value: iexpr) = + this(value, value.attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class Pass(attributeProvider: AttributeProvider) extends istmt: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class Break(attributeProvider: AttributeProvider) extends istmt: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class Continue(attributeProvider: AttributeProvider) extends istmt: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) // This is the python2 raise statement. // It is different enough from the python3 version to justify an @@ -406,17 +406,17 @@ case class RaiseP2( tback: Option[iexpr], attributeProvider: AttributeProvider ) extends istmt: - def this(typ: iexpr, inst: iexpr, tback: iexpr, attributeProvider: AttributeProvider) = - this(Option(typ), Option(inst), Option(tback), attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this(typ: iexpr, inst: iexpr, tback: iexpr, attributeProvider: AttributeProvider) = + this(Option(typ), Option(inst), Option(tback), attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) // This statement is not part of the CPython AST definition and // was added to represent parse errors inline with valid AST // statements. case class ErrorStatement(exception: Exception, attributeProvider: AttributeProvider) extends istmt: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) /////////////////////////////////////////////////////////////////////////////////////////////////// // AST expression classes @@ -425,82 +425,82 @@ sealed trait iexpr extends iast with iattributes case class BoolOp(op: iboolop, values: CollType[iexpr], attributeProvider: AttributeProvider) extends iexpr: - def this(op: iboolop, values: util.ArrayList[iexpr], attributeProvider: AttributeProvider) = - this(op, values.asScala, attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this(op: iboolop, values: util.ArrayList[iexpr], attributeProvider: AttributeProvider) = + this(op, values.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class NamedExpr(target: iexpr, value: iexpr, attributeProvider: AttributeProvider) extends iexpr: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class BinOp(left: iexpr, op: ioperator, right: iexpr, attributeProvider: AttributeProvider) extends iexpr: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class UnaryOp(op: iunaryop, operand: iexpr, attributeProvider: AttributeProvider) extends iexpr: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class Lambda(args: Arguments, body: iexpr, attributeProvider: AttributeProvider) extends iexpr: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class IfExp(test: iexpr, body: iexpr, orelse: iexpr, attributeProvider: AttributeProvider) extends iexpr: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class Dict( keys: CollType[Option[iexpr]], values: CollType[iexpr], attributeProvider: AttributeProvider ) extends iexpr: - def this( - keys: util.ArrayList[iexpr], - values: util.ArrayList[iexpr], - attributeProvider: AttributeProvider - ) = - this(keys.asScala.map(Option.apply), values.asScala, attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this( + keys: util.ArrayList[iexpr], + values: util.ArrayList[iexpr], + attributeProvider: AttributeProvider + ) = + this(keys.asScala.map(Option.apply), values.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class Set(elts: CollType[iexpr], attributeProvider: AttributeProvider) extends iexpr: - def this(elts: util.ArrayList[iexpr], attributeProvider: AttributeProvider) = - this(elts.asScala, attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this(elts: util.ArrayList[iexpr], attributeProvider: AttributeProvider) = + this(elts.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class ListComp( elt: iexpr, generators: CollType[Comprehension], attributeProvider: AttributeProvider ) extends iexpr: - def this( - elt: iexpr, - generators: util.ArrayList[Comprehension], - attributeProvider: AttributeProvider - ) = - this(elt, generators.asScala, attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this( + elt: iexpr, + generators: util.ArrayList[Comprehension], + attributeProvider: AttributeProvider + ) = + this(elt, generators.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class SetComp( elt: iexpr, generators: CollType[Comprehension], attributeProvider: AttributeProvider ) extends iexpr: - def this( - elt: iexpr, - generators: util.ArrayList[Comprehension], - attributeProvider: AttributeProvider - ) = - this(elt, generators.asScala, attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this( + elt: iexpr, + generators: util.ArrayList[Comprehension], + attributeProvider: AttributeProvider + ) = + this(elt, generators.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class DictComp( key: iexpr, @@ -508,45 +508,45 @@ case class DictComp( generators: CollType[Comprehension], attributeProvider: AttributeProvider ) extends iexpr: - def this( - key: iexpr, - value: iexpr, - generators: util.ArrayList[Comprehension], - attributeProvider: AttributeProvider - ) = - this(key, value, generators.asScala, attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this( + key: iexpr, + value: iexpr, + generators: util.ArrayList[Comprehension], + attributeProvider: AttributeProvider + ) = + this(key, value, generators.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class GeneratorExp( elt: iexpr, generators: CollType[Comprehension], attributeProvider: AttributeProvider ) extends iexpr: - def this( - elt: iexpr, - generators: util.ArrayList[Comprehension], - attributeProvider: AttributeProvider - ) = - this(elt, generators.asScala, attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this( + elt: iexpr, + generators: util.ArrayList[Comprehension], + attributeProvider: AttributeProvider + ) = + this(elt, generators.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class Await(value: iexpr, attributeProvider: AttributeProvider) extends iexpr: - def this(value: iexpr) = - this(value, value.attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this(value: iexpr) = + this(value, value.attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class Yield(value: Option[iexpr], attributeProvider: AttributeProvider) extends iexpr: - def this(value: iexpr, attributeProvider: AttributeProvider) = - this(Option(value), attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this(value: iexpr, attributeProvider: AttributeProvider) = + this(Option(value), attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class YieldFrom(value: iexpr, attributeProvider: AttributeProvider) extends iexpr: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class Compare( left: iexpr, @@ -554,16 +554,16 @@ case class Compare( comparators: CollType[iexpr], attributeProvider: AttributeProvider ) extends iexpr: - def this( - left: iexpr, - ops: util.ArrayList[icompop], - comparators: util.ArrayList[iexpr], - attributeProvider: AttributeProvider - ) = - this(left, ops.asScala, comparators.asScala, attributeProvider) + def this( + left: iexpr, + ops: util.ArrayList[icompop], + comparators: util.ArrayList[iexpr], + attributeProvider: AttributeProvider + ) = + this(left, ops.asScala, comparators.asScala, attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class Call( func: iexpr, @@ -571,16 +571,16 @@ case class Call( keywords: CollType[Keyword], attributeProvider: AttributeProvider ) extends iexpr: - def this( - func: iexpr, - args: util.ArrayList[iexpr], - keywords: util.ArrayList[Keyword], - attributeProvider: AttributeProvider - ) = - this(func, args.asScala, keywords.asScala, attributeProvider) + def this( + func: iexpr, + args: util.ArrayList[iexpr], + keywords: util.ArrayList[Keyword], + attributeProvider: AttributeProvider + ) = + this(func, args.asScala, keywords.asScala, attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) // In addition to the CPython version of this class we also stored // whether the value expression was followed by "=" in "equalSign". @@ -594,16 +594,16 @@ case class FormattedValue( equalSign: Boolean, attributeProvider: AttributeProvider ) extends iexpr: - def this( - value: iexpr, - conversion: Int, - format_spec: String, - equalSign: Boolean, - attributeProvider: AttributeProvider - ) = - this(value, conversion, Option(format_spec), equalSign, attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this( + value: iexpr, + conversion: Int, + format_spec: String, + equalSign: Boolean, + attributeProvider: AttributeProvider + ) = + this(value, conversion, Option(format_spec), equalSign, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) // In addition to the CPython version of this class we have the fields // "quote" which stores the kind of quote used and "prefix" @@ -614,49 +614,49 @@ case class JoinedString( prefix: String, attributeProvider: AttributeProvider ) extends iexpr: - def this( - values: util.List[iexpr], - quote: String, - prefix: String, - attributeProvider: AttributeProvider - ) = - this(values.asScala, quote, prefix, attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this( + values: util.List[iexpr], + quote: String, + prefix: String, + attributeProvider: AttributeProvider + ) = + this(values.asScala, quote, prefix, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class Constant(value: iconstant, attributeProvider: AttributeProvider) extends iexpr: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class Attribute(value: iexpr, attr: String, attributeProvider: AttributeProvider) extends iexpr: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class Subscript(value: iexpr, slice: iexpr, attributeProvider: AttributeProvider) extends iexpr: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class Starred(value: iexpr, attributeProvider: AttributeProvider) extends iexpr: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class Name(id: String, attributeProvider: AttributeProvider) extends iexpr: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class List(elts: CollType[iexpr], attributeProvider: AttributeProvider) extends iexpr: - def this(elts: util.ArrayList[iexpr], attributeProvider: AttributeProvider) = - this(elts.asScala, attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this(elts: util.ArrayList[iexpr], attributeProvider: AttributeProvider) = + this(elts.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class Tuple(elts: CollType[iexpr], attributeProvider: AttributeProvider) extends iexpr: - def this(elts: util.ArrayList[iexpr], attributeProvider: AttributeProvider) = - this(elts.asScala, attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this(elts: util.ArrayList[iexpr], attributeProvider: AttributeProvider) = + this(elts.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class Slice( lower: Option[iexpr], @@ -664,10 +664,10 @@ case class Slice( step: Option[iexpr], attributeProvider: AttributeProvider ) extends iexpr: - def this(lower: iexpr, upper: iexpr, step: iexpr, attributeProvider: AttributeProvider) = - this(Option(lower), Option(upper), Option(step), attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this(lower: iexpr, upper: iexpr, step: iexpr, attributeProvider: AttributeProvider) = + this(Option(lower), Option(upper), Option(step), attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) // This class is not part of the CPython AST definition at // https://docs.python.org/3/library/ast.html @@ -676,11 +676,11 @@ case class Slice( // A StringExpList must always have at least 2 elements and its elements must // be either a Constant which contains a StringConstant or a JoinedString. case class StringExpList(elts: CollType[iexpr], attributeProvider: AttributeProvider) extends iexpr: - assert(elts.size >= 2) - def this(elts: util.ArrayList[iexpr]) = - this(elts.asScala, elts.asScala.head.attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + assert(elts.size >= 2) + def this(elts: util.ArrayList[iexpr]) = + this(elts.asScala, elts.asScala.head.attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) /////////////////////////////////////////////////////////////////////////////////////////////////// // AST boolop classes @@ -688,12 +688,12 @@ case class StringExpList(elts: CollType[iexpr], attributeProvider: AttributeProv sealed trait iboolop extends iast object And extends iboolop: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case object Or extends iboolop: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) /////////////////////////////////////////////////////////////////////////////////////////////////// // AST operator classes @@ -701,44 +701,44 @@ case object Or extends iboolop: sealed trait ioperator extends iast case object Add extends ioperator: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case object Sub extends ioperator: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case object Mult extends ioperator: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case object MatMult extends ioperator: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case object Div extends ioperator: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case object Mod extends ioperator: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case object Pow extends ioperator: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case object LShift extends ioperator: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case object RShift extends ioperator: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case object BitOr extends ioperator: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case object BitXor extends ioperator: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case object BitAnd extends ioperator: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case object FloorDiv extends ioperator: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) /////////////////////////////////////////////////////////////////////////////////////////////////// // AST unaryop classes @@ -746,20 +746,20 @@ case object FloorDiv extends ioperator: sealed trait iunaryop extends iast case object Invert extends iunaryop: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case object Not extends iunaryop: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case object UAdd extends iunaryop: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case object USub extends iunaryop: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) /////////////////////////////////////////////////////////////////////////////////////////////////// // AST compop classes @@ -767,45 +767,45 @@ case object USub extends iunaryop: sealed trait icompop extends iast case object Eq extends icompop: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case object NotEq extends icompop: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case object Lt extends icompop: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case object LtE extends icompop: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case object Gt extends icompop: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case object GtE extends icompop: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case object Is extends icompop: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case object IsNot extends icompop: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case object In extends icompop: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case object NotIn extends icompop: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) /////////////////////////////////////////////////////////////////////////////////////////////////// // AST comprehension classes /////////////////////////////////////////////////////////////////////////////////////////////////// case class Comprehension(target: iexpr, iter: iexpr, ifs: CollType[iexpr], is_async: Boolean) extends iast: - def this(target: iexpr, iter: iexpr, ifs: util.ArrayList[iexpr], is_async: Boolean) = - this(target, iter, ifs.asScala, is_async) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this(target: iexpr, iter: iexpr, ifs: util.ArrayList[iexpr], is_async: Boolean) = + this(target, iter, ifs.asScala, is_async) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) /////////////////////////////////////////////////////////////////////////////////////////////////// // AST exceptHandler classes @@ -817,15 +817,15 @@ case class ExceptHandler( attributeProvider: AttributeProvider ) extends iast with iattributes: - def this( - typ: iexpr, - name: String, - body: util.ArrayList[istmt], - attributeProvider: AttributeProvider - ) = - this(Option(typ), Option(name), body.asScala, attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this( + typ: iexpr, + name: String, + body: util.ArrayList[istmt], + attributeProvider: AttributeProvider + ) = + this(Option(typ), Option(name), body.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) /////////////////////////////////////////////////////////////////////////////////////////////////// // AST arguments classes @@ -846,26 +846,26 @@ case class Arguments( kw_arg: Option[Arg], defaults: CollType[iexpr] ) extends iast: - def this( - posonlyargs: util.List[Arg], - args: util.List[Arg], - vararg: Arg, - kwonlyargs: util.List[Arg], - kw_defaults: util.List[iexpr], - kw_arg: Arg, - defaults: util.List[iexpr] - ) = - this( - posonlyargs.asScala, - args.asScala, - Option(vararg), - kwonlyargs.asScala, - kw_defaults.asScala.map(Option.apply), - Option(kw_arg), - defaults.asScala - ) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this( + posonlyargs: util.List[Arg], + args: util.List[Arg], + vararg: Arg, + kwonlyargs: util.List[Arg], + kw_defaults: util.List[iexpr], + kw_arg: Arg, + defaults: util.List[iexpr] + ) = + this( + posonlyargs.asScala, + args.asScala, + Option(vararg), + kwonlyargs.asScala, + kw_defaults.asScala.map(Option.apply), + Option(kw_arg), + defaults.asScala + ) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) end Arguments /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -878,15 +878,15 @@ case class Arg( attributeProvider: AttributeProvider ) extends iast with iattributes: - def this( - arg: String, - annotation: iexpr, - type_comment: String, - attributeProvider: AttributeProvider - ) = - this(arg, Option(annotation), Option(type_comment), attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this( + arg: String, + annotation: iexpr, + type_comment: String, + attributeProvider: AttributeProvider + ) = + this(arg, Option(annotation), Option(type_comment), attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) /////////////////////////////////////////////////////////////////////////////////////////////////// // AST keyword classes @@ -894,38 +894,38 @@ case class Arg( case class Keyword(arg: Option[String], value: iexpr, attributeProvider: AttributeProvider) extends iast with iattributes: - def this(arg: String, value: iexpr, attributeProvider: AttributeProvider) = - this(Option(arg), value, attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this(arg: String, value: iexpr, attributeProvider: AttributeProvider) = + this(Option(arg), value, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) /////////////////////////////////////////////////////////////////////////////////////////////////// // AST alias classes /////////////////////////////////////////////////////////////////////////////////////////////////// case class Alias(name: String, asName: Option[String]) extends iast: - def this(name: String, asName: String) = - this(name, Option(asName)) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this(name: String, asName: String) = + this(name, Option(asName)) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) /////////////////////////////////////////////////////////////////////////////////////////////////// // AST withitem classes /////////////////////////////////////////////////////////////////////////////////////////////////// case class Withitem(context_expr: iexpr, optional_vars: Option[iexpr]) extends iast: - def this(context_expr: iexpr, optional_vars: iexpr) = - this(context_expr, Option(optional_vars)) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this(context_expr: iexpr, optional_vars: iexpr) = + this(context_expr, Option(optional_vars)) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) /////////////////////////////////////////////////////////////////////////////////////////////////// // AST match_case classes /////////////////////////////////////////////////////////////////////////////////////////////////// case class MatchCase(pattern: ipattern, guard: Option[iexpr], body: CollType[istmt]) extends iast: - def this(pattern: ipattern, guard: iexpr, body: util.List[istmt]) = - this(pattern, Option(guard), body.asScala) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this(pattern: ipattern, guard: iexpr, body: util.List[istmt]) = + this(pattern, Option(guard), body.asScala) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) /////////////////////////////////////////////////////////////////////////////////////////////////// // AST pattern classes @@ -933,19 +933,19 @@ case class MatchCase(pattern: ipattern, guard: Option[iexpr], body: CollType[ist sealed trait ipattern extends iast with iattributes case class MatchValue(value: iexpr, attributeProvider: AttributeProvider) extends ipattern: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class MatchSingleton(value: iconstant, attributeProvider: AttributeProvider) extends ipattern: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class MatchSequence(patterns: CollType[ipattern], attributeProvider: AttributeProvider) extends ipattern: - def this(patterns: util.List[ipattern], attributeProvider: AttributeProvider) = - this(patterns.asScala, attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this(patterns: util.List[ipattern], attributeProvider: AttributeProvider) = + this(patterns.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class MatchMapping( keys: CollType[iexpr], @@ -953,15 +953,15 @@ case class MatchMapping( rest: Option[String], attributeProvider: AttributeProvider ) extends ipattern: - def this( - keys: util.List[iexpr], - patterns: util.List[ipattern], - rest: String, - attributeProvider: AttributeProvider - ) = - this(keys.asScala, patterns.asScala, Option(rest), attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this( + keys: util.List[iexpr], + patterns: util.List[ipattern], + rest: String, + attributeProvider: AttributeProvider + ) = + this(keys.asScala, patterns.asScala, Option(rest), attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class MatchClass( cls: iexpr, @@ -970,58 +970,58 @@ case class MatchClass( kwd_patterns: CollType[ipattern], attributeProvider: AttributeProvider ) extends ipattern: - def this( - cls: iexpr, - patterns: util.List[ipattern], - kwd_attrs: util.List[String], - kwd_patterns: util.List[ipattern], - attributeProvider: AttributeProvider - ) = - this(cls, patterns.asScala, kwd_attrs.asScala, kwd_patterns.asScala, attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this( + cls: iexpr, + patterns: util.List[ipattern], + kwd_attrs: util.List[String], + kwd_patterns: util.List[ipattern], + attributeProvider: AttributeProvider + ) = + this(cls, patterns.asScala, kwd_attrs.asScala, kwd_patterns.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class MatchStar(name: Option[String], attributeProvider: AttributeProvider) extends ipattern: - def this(name: String, attributeProvider: AttributeProvider) = - this(Option(name), attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this(name: String, attributeProvider: AttributeProvider) = + this(Option(name), attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class MatchAs( pattern: Option[ipattern], name: Option[String], attributeProvider: AttributeProvider ) extends ipattern: - def this(pattern: ipattern, name: String, attributeProvider: AttributeProvider) = - this(Option(pattern), Option(name), attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this(pattern: ipattern, name: String, attributeProvider: AttributeProvider) = + this(Option(pattern), Option(name), attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class MatchOr(patterns: CollType[ipattern], attributeProvider: AttributeProvider) extends ipattern: - def this(patterns: util.List[ipattern], attributeProvider: AttributeProvider) = - this(patterns.asScala, attributeProvider) - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + def this(patterns: util.List[ipattern], attributeProvider: AttributeProvider) = + this(patterns.asScala, attributeProvider) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) /////////////////////////////////////////////////////////////////////////////////////////////////// // AST type_ignore classes /////////////////////////////////////////////////////////////////////////////////////////////////// case class TypeIgnore(lineno: Int, tag: String) extends iast: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) /////////////////////////////////////////////////////////////////////////////////////////////////// // AST attributes classes /////////////////////////////////////////////////////////////////////////////////////////////////// trait iattributes: - val attributeProvider: AttributeProvider - def lineno: Int = attributeProvider.lineno - def col_offset: Int = attributeProvider.col_offset - def input_offset: Int = attributeProvider.input_offset - def end_lineno: Int = attributeProvider.end_lineno - def end_col_offset: Int = attributeProvider.end_col_offset - def end_input_offset: Int = attributeProvider.end_input_offset + val attributeProvider: AttributeProvider + def lineno: Int = attributeProvider.lineno + def col_offset: Int = attributeProvider.col_offset + def input_offset: Int = attributeProvider.input_offset + def end_lineno: Int = attributeProvider.end_lineno + def end_col_offset: Int = attributeProvider.end_col_offset + def end_input_offset: Int = attributeProvider.end_input_offset /////////////////////////////////////////////////////////////////////////////////////////////////// // AST constant classes @@ -1029,26 +1029,26 @@ trait iattributes: sealed trait iconstant extends iast case class StringConstant(value: String, quote: String, prefix: String) extends iconstant: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class JoinedStringConstant(value: String) extends iconstant: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class BoolConstant(value: Boolean) extends iconstant: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class IntConstant(value: String) extends iconstant: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class FloatConstant(value: String) extends iconstant: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case class ImaginaryConstant(value: String) extends iconstant: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case object NoneConstant extends iconstant: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) case object EllipsisConstant extends iconstant: - override def accept[T](visitor: AstVisitor[T]): T = - visitor.visit(this) + override def accept[T](visitor: AstVisitor[T]): T = + visitor.visit(this) diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/ast/AttributeProvider.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/ast/AttributeProvider.scala index f8b63e19..16213f94 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/ast/AttributeProvider.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/ast/AttributeProvider.scala @@ -3,50 +3,50 @@ package io.appthreat.pythonparser.ast import io.appthreat.pythonparser.Token trait AttributeProvider: - def lineno: Int - def col_offset: Int - def input_offset: Int - def end_lineno: Int - def end_col_offset: Int - def end_input_offset: Int + def lineno: Int + def col_offset: Int + def input_offset: Int + def end_lineno: Int + def end_col_offset: Int + def end_input_offset: Int - override def toString: String = - s"$lineno,$col_offset,$end_lineno,$end_col_offset" + override def toString: String = + s"$lineno,$col_offset,$end_lineno,$end_col_offset" class TokenAttributeProvider(startToken: Token, endToken: Token) extends AttributeProvider: - override def lineno: Int = - startToken.beginLine + override def lineno: Int = + startToken.beginLine - override def col_offset: Int = - startToken.beginColumn + override def col_offset: Int = + startToken.beginColumn - override def input_offset: Int = - startToken.startPos + override def input_offset: Int = + startToken.startPos - override def end_lineno: Int = - endToken.endLine + override def end_lineno: Int = + endToken.endLine - override def end_col_offset: Int = - endToken.endColumn + override def end_col_offset: Int = + endToken.endColumn - override def end_input_offset: Int = - endToken.endPos + override def end_input_offset: Int = + endToken.endPos class NodeAttributeProvider(astNode: iattributes, endToken: Token) extends AttributeProvider: - override def lineno: Int = - astNode.lineno + override def lineno: Int = + astNode.lineno - override def col_offset: Int = - astNode.col_offset + override def col_offset: Int = + astNode.col_offset - override def input_offset: Int = - astNode.input_offset + override def input_offset: Int = + astNode.input_offset - override def end_lineno: Int = - endToken.endLine + override def end_lineno: Int = + endToken.endLine - override def end_col_offset: Int = - endToken.endColumn + override def end_col_offset: Int = + endToken.endColumn - override def end_input_offset: Int = - endToken.endPos + override def end_input_offset: Int = + endToken.endPos diff --git a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/ast/package.scala b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/ast/package.scala index d2b5152d..a8afbe3c 100644 --- a/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/ast/package.scala +++ b/platform/frontends/pysrc2cpg/src/main/scala/io/appthreat/pythonparser/ast/package.scala @@ -3,4 +3,4 @@ package io.appthreat.pythonparser import scala.collection.mutable package object ast: - type CollType[T] = mutable.Seq[T] + type CollType[T] = mutable.Seq[T] diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/Ast.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/Ast.scala index 0e566733..af39267f 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/Ast.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/Ast.scala @@ -10,75 +10,75 @@ import overflowdb.SchemaViolationException case class AstEdge(src: NewNode, dst: NewNode) enum ValidationMode: - case Enabled, Disabled + case Enabled, Disabled object Ast: - private val logger = LoggerFactory.getLogger(getClass) - - def apply(node: NewNode)(implicit withSchemaValidation: ValidationMode): Ast = - Ast(Vector.empty :+ node) - def apply()(implicit withSchemaValidation: ValidationMode): Ast = new Ast(Vector.empty) - - /** Copy nodes/edges of given `AST` into the given `diffGraph`. - */ - def storeInDiffGraph(ast: Ast, diffGraph: DiffGraphBuilder): Unit = - - setOrderWhereNotSet(ast) - - ast.nodes.foreach { node => - diffGraph.addNode(node) - } - ast.edges.foreach { edge => - diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.AST) - } - ast.conditionEdges.foreach { edge => - diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.CONDITION) - } - ast.receiverEdges.foreach { edge => - diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.RECEIVER) - } - ast.refEdges.foreach { edge => - diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.REF) - } - - ast.argEdges.foreach { edge => - diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.ARGUMENT) - } + private val logger = LoggerFactory.getLogger(getClass) + + def apply(node: NewNode)(implicit withSchemaValidation: ValidationMode): Ast = + Ast(Vector.empty :+ node) + def apply()(implicit withSchemaValidation: ValidationMode): Ast = new Ast(Vector.empty) + + /** Copy nodes/edges of given `AST` into the given `diffGraph`. + */ + def storeInDiffGraph(ast: Ast, diffGraph: DiffGraphBuilder): Unit = + + setOrderWhereNotSet(ast) + + ast.nodes.foreach { node => + diffGraph.addNode(node) + } + ast.edges.foreach { edge => + diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.AST) + } + ast.conditionEdges.foreach { edge => + diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.CONDITION) + } + ast.receiverEdges.foreach { edge => + diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.RECEIVER) + } + ast.refEdges.foreach { edge => + diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.REF) + } + + ast.argEdges.foreach { edge => + diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.ARGUMENT) + } + + ast.bindsEdges.foreach { edge => + diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.BINDS) + } + end storeInDiffGraph + + def neighbourValidation(src: NewNode, dst: NewNode, edge: String)(implicit + withSchemaValidation: ValidationMode + ): Unit = + if + withSchemaValidation == ValidationMode.Enabled && + !(src.isValidOutNeighbor(edge, dst) && dst.isValidInNeighbor(edge, src)) + then + throw new SchemaViolationException( + s"Malformed AST detected: (${src.label()}) -[$edge]-> (${dst.label()}) violates the schema." + ) - ast.bindsEdges.foreach { edge => - diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.BINDS) - } - end storeInDiffGraph - - def neighbourValidation(src: NewNode, dst: NewNode, edge: String)(implicit - withSchemaValidation: ValidationMode - ): Unit = - if - withSchemaValidation == ValidationMode.Enabled && - !(src.isValidOutNeighbor(edge, dst) && dst.isValidInNeighbor(edge, src)) - then - throw new SchemaViolationException( - s"Malformed AST detected: (${src.label()}) -[$edge]-> (${dst.label()}) violates the schema." - ) - - /** For all `order` fields that are unset, derive the `order` field automatically by determining - * the position of the child among its siblings. - */ - private def setOrderWhereNotSet(ast: Ast): Unit = - ast.root.collect { case r: AstNodeNew => - if r.order == PropertyDefaults.Order then - r.order = 1 - } - val siblings = ast.edges.groupBy(_.src).map { case (_, edgeToChild) => - edgeToChild.map(_.dst) - } - siblings.foreach { children => - children.zipWithIndex.collect { case (c: AstNodeNew, i) => - if c.order == PropertyDefaults.Order then - c.order = i + 1 - } + /** For all `order` fields that are unset, derive the `order` field automatically by determining + * the position of the child among its siblings. + */ + private def setOrderWhereNotSet(ast: Ast): Unit = + ast.root.collect { case r: AstNodeNew => + if r.order == PropertyDefaults.Order then + r.order = 1 + } + val siblings = ast.edges.groupBy(_.src).map { case (_, edgeToChild) => + edgeToChild.map(_.dst) + } + siblings.foreach { children => + children.zipWithIndex.collect { case (c: AstNodeNew, i) => + if c.order == PropertyDefaults.Order then + c.order = i + 1 } + } end Ast case class Ast( @@ -93,160 +93,160 @@ case class Ast( argEdges: collection.Seq[AstEdge] = Vector.empty )(implicit withSchemaValidation: ValidationMode = ValidationMode.Disabled): - def root: Option[NewNode] = nodes.headOption - - def rightMostLeaf: Option[NewNode] = nodes.lastOption - - /** AST that results when adding `other` as a child to this AST. `other` is connected to this - * AST's root node. - */ - def withChild(other: Ast): Ast = - Ast( - nodes ++ other.nodes, - edges = edges ++ other.edges ++ root.toList.flatMap(r => - other.root.toList.map { rc => - Ast.neighbourValidation(r, rc, EdgeTypes.AST) - AstEdge(r, rc) - } - ), - conditionEdges = conditionEdges ++ other.conditionEdges, - argEdges = argEdges ++ other.argEdges, - receiverEdges = receiverEdges ++ other.receiverEdges, - refEdges = refEdges ++ other.refEdges, - bindsEdges = bindsEdges ++ other.bindsEdges - ) + def root: Option[NewNode] = nodes.headOption - def merge(other: Ast): Ast = - Ast( - nodes ++ other.nodes, - edges = edges ++ other.edges, - conditionEdges = conditionEdges ++ other.conditionEdges, - argEdges = argEdges ++ other.argEdges, - receiverEdges = receiverEdges ++ other.receiverEdges, - refEdges = refEdges ++ other.refEdges, - bindsEdges = bindsEdges ++ other.bindsEdges - ) - - /** AST that results when adding all ASTs in `asts` as children, that is, connecting them to the - * root node of this AST. - */ - def withChildren(asts: collection.Seq[Ast]): Ast = - if asts.isEmpty then - this - else - // we do this iteratively as a recursive solution which will fail with - // a StackOverflowException if there are too many elements in .tail. - var ast = withChild(asts.head) - asts.tail.foreach(c => ast = ast.withChild(c)) - ast - - def withConditionEdge(src: NewNode, dst: NewNode): Ast = - Ast.neighbourValidation(src, dst, EdgeTypes.CONDITION) - this.copy(conditionEdges = conditionEdges ++ List(AstEdge(src, dst))) - - def withRefEdge(src: NewNode, dst: NewNode): Ast = - Ast.neighbourValidation(src, dst, EdgeTypes.REF) - this.copy(refEdges = refEdges ++ List(AstEdge(src, dst))) - - def withBindsEdge(src: NewNode, dst: NewNode): Ast = - Ast.neighbourValidation(src, dst, EdgeTypes.BINDS) - this.copy(bindsEdges = bindsEdges ++ List(AstEdge(src, dst))) - - def withReceiverEdge(src: NewNode, dst: NewNode): Ast = - Ast.neighbourValidation(src, dst, EdgeTypes.RECEIVER) - this.copy(receiverEdges = receiverEdges ++ List(AstEdge(src, dst))) - - def withArgEdge(src: NewNode, dst: NewNode): Ast = - Ast.neighbourValidation(src, dst, EdgeTypes.ARGUMENT) - this.copy(argEdges = argEdges ++ List(AstEdge(src, dst))) - - def withArgEdges(src: NewNode, dsts: Seq[NewNode]): Ast = - dsts.foreach(dst => Ast.neighbourValidation(src, dst, EdgeTypes.ARGUMENT)) - this.copy(argEdges = argEdges ++ dsts.map(AstEdge(src, _))) - - def withArgEdges(src: NewNode, dsts: Seq[NewNode], argIndexStart: Int): Ast = - var index = argIndexStart - this.copy(argEdges = argEdges ++ dsts.map { dst => - addArgumentIndex(dst, index) - index += 1 - Ast.neighbourValidation(src, dst, EdgeTypes.ARGUMENT) - AstEdge(src, dst) - }) - private def addArgumentIndex(node: NewNode, argIndex: Int): Unit = node match - case n: NewBlock => n.argumentIndex = argIndex - case n: NewCall => n.argumentIndex = argIndex - case n: NewFieldIdentifier => n.argumentIndex = argIndex - case n: NewIdentifier => n.argumentIndex = argIndex - case n: NewMethodRef => n.argumentIndex = argIndex - case n: NewTypeRef => n.argumentIndex = argIndex - case n: NewUnknown => n.argumentIndex = argIndex - case n: NewControlStructure => n.argumentIndex = argIndex - case n: NewLiteral => n.argumentIndex = argIndex - case n: NewReturn => n.argumentIndex = argIndex - - def withConditionEdges(src: NewNode, dsts: List[NewNode]): Ast = - dsts.foreach(dst => Ast.neighbourValidation(src, dst, EdgeTypes.CONDITION)) - this.copy(conditionEdges = conditionEdges ++ dsts.map(AstEdge(src, _))) - - def withRefEdges(src: NewNode, dsts: List[NewNode]): Ast = - dsts.foreach(dst => Ast.neighbourValidation(src, dst, EdgeTypes.REF)) - this.copy(refEdges = refEdges ++ dsts.map(AstEdge(src, _))) - - def withBindsEdges(src: NewNode, dsts: List[NewNode]): Ast = - dsts.foreach(dst => Ast.neighbourValidation(src, dst, EdgeTypes.BINDS)) - this.copy(bindsEdges = bindsEdges ++ dsts.map(AstEdge(src, _))) - - def withReceiverEdges(src: NewNode, dsts: List[NewNode]): Ast = - dsts.foreach(dst => Ast.neighbourValidation(src, dst, EdgeTypes.RECEIVER)) - this.copy(receiverEdges = receiverEdges ++ dsts.map(AstEdge(src, _))) - - /** Returns a deep copy of the sub tree rooted in `node`. If `order` is set, then the `order` - * and `argumentIndex` fields of the new root node are set to `order`. If `replacementNode` is - * set, then this replaces `node` in the new copy. - */ - def subTreeCopy( - node: AstNodeNew, - argIndex: Int = -1, - replacementNode: Option[AstNodeNew] = None - ): Ast = - val newNode = replacementNode match - case Some(n) => n - case None => node.copy - if argIndex != -1 then - // newNode.order = argIndex - newNode match - case expr: ExpressionNew => - expr.argumentIndex = argIndex - case _ => - - val astChildren = edges.filter(_.src == node).map(_.dst) - val newChildren = astChildren.map { x => - this.subTreeCopy(x.asInstanceOf[AstNodeNew]) - } + def rightMostLeaf: Option[NewNode] = nodes.lastOption - val oldToNew = astChildren.zip(newChildren).map { case (old, n) => old -> n.root.get }.toMap - def newIfExists(x: NewNode) = - oldToNew.getOrElse(x, x) - - val newArgEdges = - argEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) - val newConditionEdges = - conditionEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) - val newRefEdges = - refEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) - val newBindsEdges = - bindsEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) - val newReceiverEdges = - receiverEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) - - Ast(newNode) - .copy( - argEdges = newArgEdges, - conditionEdges = newConditionEdges, - refEdges = newRefEdges, - bindsEdges = newBindsEdges, - receiverEdges = newReceiverEdges - ) - .withChildren(newChildren) - end subTreeCopy + /** AST that results when adding `other` as a child to this AST. `other` is connected to this + * AST's root node. + */ + def withChild(other: Ast): Ast = + Ast( + nodes ++ other.nodes, + edges = edges ++ other.edges ++ root.toList.flatMap(r => + other.root.toList.map { rc => + Ast.neighbourValidation(r, rc, EdgeTypes.AST) + AstEdge(r, rc) + } + ), + conditionEdges = conditionEdges ++ other.conditionEdges, + argEdges = argEdges ++ other.argEdges, + receiverEdges = receiverEdges ++ other.receiverEdges, + refEdges = refEdges ++ other.refEdges, + bindsEdges = bindsEdges ++ other.bindsEdges + ) + + def merge(other: Ast): Ast = + Ast( + nodes ++ other.nodes, + edges = edges ++ other.edges, + conditionEdges = conditionEdges ++ other.conditionEdges, + argEdges = argEdges ++ other.argEdges, + receiverEdges = receiverEdges ++ other.receiverEdges, + refEdges = refEdges ++ other.refEdges, + bindsEdges = bindsEdges ++ other.bindsEdges + ) + + /** AST that results when adding all ASTs in `asts` as children, that is, connecting them to the + * root node of this AST. + */ + def withChildren(asts: collection.Seq[Ast]): Ast = + if asts.isEmpty then + this + else + // we do this iteratively as a recursive solution which will fail with + // a StackOverflowException if there are too many elements in .tail. + var ast = withChild(asts.head) + asts.tail.foreach(c => ast = ast.withChild(c)) + ast + + def withConditionEdge(src: NewNode, dst: NewNode): Ast = + Ast.neighbourValidation(src, dst, EdgeTypes.CONDITION) + this.copy(conditionEdges = conditionEdges ++ List(AstEdge(src, dst))) + + def withRefEdge(src: NewNode, dst: NewNode): Ast = + Ast.neighbourValidation(src, dst, EdgeTypes.REF) + this.copy(refEdges = refEdges ++ List(AstEdge(src, dst))) + + def withBindsEdge(src: NewNode, dst: NewNode): Ast = + Ast.neighbourValidation(src, dst, EdgeTypes.BINDS) + this.copy(bindsEdges = bindsEdges ++ List(AstEdge(src, dst))) + + def withReceiverEdge(src: NewNode, dst: NewNode): Ast = + Ast.neighbourValidation(src, dst, EdgeTypes.RECEIVER) + this.copy(receiverEdges = receiverEdges ++ List(AstEdge(src, dst))) + + def withArgEdge(src: NewNode, dst: NewNode): Ast = + Ast.neighbourValidation(src, dst, EdgeTypes.ARGUMENT) + this.copy(argEdges = argEdges ++ List(AstEdge(src, dst))) + + def withArgEdges(src: NewNode, dsts: Seq[NewNode]): Ast = + dsts.foreach(dst => Ast.neighbourValidation(src, dst, EdgeTypes.ARGUMENT)) + this.copy(argEdges = argEdges ++ dsts.map(AstEdge(src, _))) + + def withArgEdges(src: NewNode, dsts: Seq[NewNode], argIndexStart: Int): Ast = + var index = argIndexStart + this.copy(argEdges = argEdges ++ dsts.map { dst => + addArgumentIndex(dst, index) + index += 1 + Ast.neighbourValidation(src, dst, EdgeTypes.ARGUMENT) + AstEdge(src, dst) + }) + private def addArgumentIndex(node: NewNode, argIndex: Int): Unit = node match + case n: NewBlock => n.argumentIndex = argIndex + case n: NewCall => n.argumentIndex = argIndex + case n: NewFieldIdentifier => n.argumentIndex = argIndex + case n: NewIdentifier => n.argumentIndex = argIndex + case n: NewMethodRef => n.argumentIndex = argIndex + case n: NewTypeRef => n.argumentIndex = argIndex + case n: NewUnknown => n.argumentIndex = argIndex + case n: NewControlStructure => n.argumentIndex = argIndex + case n: NewLiteral => n.argumentIndex = argIndex + case n: NewReturn => n.argumentIndex = argIndex + + def withConditionEdges(src: NewNode, dsts: List[NewNode]): Ast = + dsts.foreach(dst => Ast.neighbourValidation(src, dst, EdgeTypes.CONDITION)) + this.copy(conditionEdges = conditionEdges ++ dsts.map(AstEdge(src, _))) + + def withRefEdges(src: NewNode, dsts: List[NewNode]): Ast = + dsts.foreach(dst => Ast.neighbourValidation(src, dst, EdgeTypes.REF)) + this.copy(refEdges = refEdges ++ dsts.map(AstEdge(src, _))) + + def withBindsEdges(src: NewNode, dsts: List[NewNode]): Ast = + dsts.foreach(dst => Ast.neighbourValidation(src, dst, EdgeTypes.BINDS)) + this.copy(bindsEdges = bindsEdges ++ dsts.map(AstEdge(src, _))) + + def withReceiverEdges(src: NewNode, dsts: List[NewNode]): Ast = + dsts.foreach(dst => Ast.neighbourValidation(src, dst, EdgeTypes.RECEIVER)) + this.copy(receiverEdges = receiverEdges ++ dsts.map(AstEdge(src, _))) + + /** Returns a deep copy of the sub tree rooted in `node`. If `order` is set, then the `order` and + * `argumentIndex` fields of the new root node are set to `order`. If `replacementNode` is set, + * then this replaces `node` in the new copy. + */ + def subTreeCopy( + node: AstNodeNew, + argIndex: Int = -1, + replacementNode: Option[AstNodeNew] = None + ): Ast = + val newNode = replacementNode match + case Some(n) => n + case None => node.copy + if argIndex != -1 then + // newNode.order = argIndex + newNode match + case expr: ExpressionNew => + expr.argumentIndex = argIndex + case _ => + + val astChildren = edges.filter(_.src == node).map(_.dst) + val newChildren = astChildren.map { x => + this.subTreeCopy(x.asInstanceOf[AstNodeNew]) + } + + val oldToNew = astChildren.zip(newChildren).map { case (old, n) => old -> n.root.get }.toMap + def newIfExists(x: NewNode) = + oldToNew.getOrElse(x, x) + + val newArgEdges = + argEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) + val newConditionEdges = + conditionEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) + val newRefEdges = + refEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) + val newBindsEdges = + bindsEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) + val newReceiverEdges = + receiverEdges.filter(_.src == node).map(x => AstEdge(newNode, newIfExists(x.dst))) + + Ast(newNode) + .copy( + argEdges = newArgEdges, + conditionEdges = newConditionEdges, + refEdges = newRefEdges, + bindsEdges = newBindsEdges, + receiverEdges = newReceiverEdges + ) + .withChildren(newChildren) + end subTreeCopy end Ast diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/AstCreatorBase.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/AstCreatorBase.scala index 320f6b94..bc294535 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/AstCreatorBase.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/AstCreatorBase.scala @@ -9,329 +9,329 @@ import io.shiftleft.semanticcpg.language.types.structure.NamespaceTraversal import overflowdb.BatchedUpdate.DiffGraphBuilder abstract class AstCreatorBase(filename: String)(implicit withSchemaValidation: ValidationMode): - val diffGraph: DiffGraphBuilder = new DiffGraphBuilder - private val closureKeyPool = new IntervalKeyPool(first = 0, last = Long.MaxValue) - def createAst(): DiffGraphBuilder + val diffGraph: DiffGraphBuilder = new DiffGraphBuilder + private val closureKeyPool = new IntervalKeyPool(first = 0, last = Long.MaxValue) + def createAst(): DiffGraphBuilder - /** Create a global namespace block for the given `filename` - */ - def globalNamespaceBlock(): NewNamespaceBlock = - val name = NamespaceTraversal.globalNamespaceName - val fullName = MetaDataPass.getGlobalNamespaceBlockFullName(Some(filename)) - NewNamespaceBlock() - .name(name) - .fullName(fullName) - .filename(filename) - .order(1) + /** Create a global namespace block for the given `filename` + */ + def globalNamespaceBlock(): NewNamespaceBlock = + val name = NamespaceTraversal.globalNamespaceName + val fullName = MetaDataPass.getGlobalNamespaceBlockFullName(Some(filename)) + NewNamespaceBlock() + .name(name) + .fullName(fullName) + .filename(filename) + .order(1) - /** Creates an AST that represents an annotation, including its content (annotation parameter - * assignments). - */ - def annotationAst(annotation: NewAnnotation, children: Seq[Ast]): Ast = - val annotationAst = Ast(annotation) - annotationAst.withChildren(children) + /** Creates an AST that represents an annotation, including its content (annotation parameter + * assignments). + */ + def annotationAst(annotation: NewAnnotation, children: Seq[Ast]): Ast = + val annotationAst = Ast(annotation) + annotationAst.withChildren(children) - /** Creates an AST that represents an annotation assignment with a name for the assigned value, - * its overall code, and the respective assignment AST. - */ - def annotationAssignmentAst( - assignmentValueName: String, - code: String, - assignmentAst: Ast - ): Ast = - val parameter = NewAnnotationParameter().code(assignmentValueName) - val assign = NewAnnotationParameterAssign().code(code) - val assignChildren = List(Ast(parameter), assignmentAst) - setArgumentIndices(assignChildren) - Ast(assign) - .withChild(Ast(parameter)) - .withChild(assignmentAst) + /** Creates an AST that represents an annotation assignment with a name for the assigned value, + * its overall code, and the respective assignment AST. + */ + def annotationAssignmentAst( + assignmentValueName: String, + code: String, + assignmentAst: Ast + ): Ast = + val parameter = NewAnnotationParameter().code(assignmentValueName) + val assign = NewAnnotationParameterAssign().code(code) + val assignChildren = List(Ast(parameter), assignmentAst) + setArgumentIndices(assignChildren) + Ast(assign) + .withChild(Ast(parameter)) + .withChild(assignmentAst) - /** Creates an AST that represents an entire method, including its content. - */ - def methodAst( - method: NewMethod, - parameters: Seq[Ast], - body: Ast, - methodReturn: NewMethodReturn, - modifiers: Seq[NewModifier] = Nil - ): Ast = - methodAstWithAnnotations( - method, - parameters, - body, - methodReturn, - modifiers, - annotations = Nil - ) + /** Creates an AST that represents an entire method, including its content. + */ + def methodAst( + method: NewMethod, + parameters: Seq[Ast], + body: Ast, + methodReturn: NewMethodReturn, + modifiers: Seq[NewModifier] = Nil + ): Ast = + methodAstWithAnnotations( + method, + parameters, + body, + methodReturn, + modifiers, + annotations = Nil + ) - /** Creates an AST that represents an entire method, including its content and with support for - * both method and parameter annotations. - */ - def methodAstWithAnnotations( - method: NewMethod, - parameters: Seq[Ast], - body: Ast, - methodReturn: NewMethodReturn, - modifiers: Seq[NewModifier] = Nil, - annotations: Seq[Ast] = Nil - ): Ast = - Ast(method) - .withChildren(parameters) - .withChild(body) - .withChildren(modifiers.map(Ast(_))) - .withChildren(annotations) - .withChild(Ast(methodReturn)) + /** Creates an AST that represents an entire method, including its content and with support for + * both method and parameter annotations. + */ + def methodAstWithAnnotations( + method: NewMethod, + parameters: Seq[Ast], + body: Ast, + methodReturn: NewMethodReturn, + modifiers: Seq[NewModifier] = Nil, + annotations: Seq[Ast] = Nil + ): Ast = + Ast(method) + .withChildren(parameters) + .withChild(body) + .withChildren(modifiers.map(Ast(_))) + .withChildren(annotations) + .withChild(Ast(methodReturn)) - /** Creates an AST that represents a method stub, containing information about the method, its - * parameters, and the return type. - */ - def methodStubAst( - method: NewMethod, - parameters: Seq[NewMethodParameterIn], - methodReturn: NewMethodReturn, - modifiers: Seq[NewModifier] = Nil - ): Ast = - Ast(method) - .withChildren(parameters.map(Ast(_))) - .withChild(Ast(NewBlock())) - .withChildren(modifiers.map(Ast(_))) - .withChild(Ast(methodReturn)) + /** Creates an AST that represents a method stub, containing information about the method, its + * parameters, and the return type. + */ + def methodStubAst( + method: NewMethod, + parameters: Seq[NewMethodParameterIn], + methodReturn: NewMethodReturn, + modifiers: Seq[NewModifier] = Nil + ): Ast = + Ast(method) + .withChildren(parameters.map(Ast(_))) + .withChild(Ast(NewBlock())) + .withChildren(modifiers.map(Ast(_))) + .withChild(Ast(methodReturn)) - def staticInitMethodAst( - initAsts: List[Ast], - fullName: String, - signature: Option[String], - returnType: String, - fileName: Option[String] = None, - lineNumber: Option[Integer] = None, - columnNumber: Option[Integer] = None - ): Ast = - val methodNode = NewMethod() - .name(Defines.StaticInitMethodName) - .fullName(fullName) - .lineNumber(lineNumber) - .columnNumber(columnNumber) - if signature.isDefined then - methodNode.signature(signature.get) - if fileName.isDefined then - methodNode.filename(fileName.get) - val staticModifier = NewModifier().modifierType(ModifierTypes.STATIC) - val body = blockAst(NewBlock(), initAsts) - val methodReturn = newMethodReturnNode(returnType, None, None, None) - methodAst(methodNode, Nil, body, methodReturn, List(staticModifier)) - end staticInitMethodAst + def staticInitMethodAst( + initAsts: List[Ast], + fullName: String, + signature: Option[String], + returnType: String, + fileName: Option[String] = None, + lineNumber: Option[Integer] = None, + columnNumber: Option[Integer] = None + ): Ast = + val methodNode = NewMethod() + .name(Defines.StaticInitMethodName) + .fullName(fullName) + .lineNumber(lineNumber) + .columnNumber(columnNumber) + if signature.isDefined then + methodNode.signature(signature.get) + if fileName.isDefined then + methodNode.filename(fileName.get) + val staticModifier = NewModifier().modifierType(ModifierTypes.STATIC) + val body = blockAst(NewBlock(), initAsts) + val methodReturn = newMethodReturnNode(returnType, None, None, None) + methodAst(methodNode, Nil, body, methodReturn, List(staticModifier)) + end staticInitMethodAst - /** For a given return node and arguments, create an AST that represents the return instruction. - * The main purpose of this method is to automatically assign the correct argument indices. - */ - def returnAst(returnNode: NewReturn, arguments: Seq[Ast] = List()): Ast = - setArgumentIndices(arguments) - Ast(returnNode) - .withChildren(arguments) - .withArgEdges(returnNode, arguments.flatMap(_.root)) + /** For a given return node and arguments, create an AST that represents the return instruction. + * The main purpose of this method is to automatically assign the correct argument indices. + */ + def returnAst(returnNode: NewReturn, arguments: Seq[Ast] = List()): Ast = + setArgumentIndices(arguments) + Ast(returnNode) + .withChildren(arguments) + .withArgEdges(returnNode, arguments.flatMap(_.root)) - /** For a given node, condition AST and children ASTs, create an AST that represents the control - * structure. The main purpose of this method is to automatically assign the correct condition - * edges. - */ - def controlStructureAst( - controlStructureNode: NewControlStructure, - condition: Option[Ast], - children: Seq[Ast] = Seq(), - placeConditionLast: Boolean = false - ): Ast = - condition match - case Some(conditionAst) => - Ast(controlStructureNode) - .withChildren(if placeConditionLast then children :+ conditionAst - else conditionAst +: children) - .withConditionEdges(controlStructureNode, List(conditionAst.root).flatten) - case _ => - Ast(controlStructureNode) - .withChildren(children) + /** For a given node, condition AST and children ASTs, create an AST that represents the control + * structure. The main purpose of this method is to automatically assign the correct condition + * edges. + */ + def controlStructureAst( + controlStructureNode: NewControlStructure, + condition: Option[Ast], + children: Seq[Ast] = Seq(), + placeConditionLast: Boolean = false + ): Ast = + condition match + case Some(conditionAst) => + Ast(controlStructureNode) + .withChildren(if placeConditionLast then children :+ conditionAst + else conditionAst +: children) + .withConditionEdges(controlStructureNode, List(conditionAst.root).flatten) + case _ => + Ast(controlStructureNode) + .withChildren(children) - def wrapMultipleInBlock( - asts: Seq[Ast], - lineNumber: Option[Integer], - columnNumber: Option[Integer] - ): Ast = - asts.toList match - case Nil => blockAst(NewBlock().lineNumber(lineNumber).columnNumber(columnNumber)) - case ast :: Nil => ast - case astList => - blockAst(NewBlock().lineNumber(lineNumber).columnNumber(columnNumber), astList) + def wrapMultipleInBlock( + asts: Seq[Ast], + lineNumber: Option[Integer], + columnNumber: Option[Integer] + ): Ast = + asts.toList match + case Nil => blockAst(NewBlock().lineNumber(lineNumber).columnNumber(columnNumber)) + case ast :: Nil => ast + case astList => + blockAst(NewBlock().lineNumber(lineNumber).columnNumber(columnNumber), astList) - def whileAst( - condition: Option[Ast], - body: Seq[Ast], - code: Option[String] = None, - lineNumber: Option[Integer] = None, - columnNumber: Option[Integer] = None - ): Ast = - var whileNode = NewControlStructure() - .controlStructureType(ControlStructureTypes.WHILE) - .lineNumber(lineNumber) - .columnNumber(columnNumber) - if code.isDefined then - whileNode = whileNode.code(code.get) - controlStructureAst(whileNode, condition, body) + def whileAst( + condition: Option[Ast], + body: Seq[Ast], + code: Option[String] = None, + lineNumber: Option[Integer] = None, + columnNumber: Option[Integer] = None + ): Ast = + var whileNode = NewControlStructure() + .controlStructureType(ControlStructureTypes.WHILE) + .lineNumber(lineNumber) + .columnNumber(columnNumber) + if code.isDefined then + whileNode = whileNode.code(code.get) + controlStructureAst(whileNode, condition, body) - def doWhileAst( - condition: Option[Ast], - body: Seq[Ast], - code: Option[String] = None, - lineNumber: Option[Integer] = None, - columnNumber: Option[Integer] = None - ): Ast = - var doWhileNode = NewControlStructure() - .controlStructureType(ControlStructureTypes.DO) - .lineNumber(lineNumber) - .columnNumber(columnNumber) - if code.isDefined then - doWhileNode = doWhileNode.code(code.get) - controlStructureAst(doWhileNode, condition, body, placeConditionLast = true) + def doWhileAst( + condition: Option[Ast], + body: Seq[Ast], + code: Option[String] = None, + lineNumber: Option[Integer] = None, + columnNumber: Option[Integer] = None + ): Ast = + var doWhileNode = NewControlStructure() + .controlStructureType(ControlStructureTypes.DO) + .lineNumber(lineNumber) + .columnNumber(columnNumber) + if code.isDefined then + doWhileNode = doWhileNode.code(code.get) + controlStructureAst(doWhileNode, condition, body, placeConditionLast = true) - def forAst( - forNode: NewControlStructure, - locals: Seq[Ast], - initAsts: Seq[Ast], - conditionAsts: Seq[Ast], - updateAsts: Seq[Ast], - bodyAst: Ast - ): Ast = - forAst(forNode, locals, initAsts, conditionAsts, updateAsts, Seq(bodyAst)) + def forAst( + forNode: NewControlStructure, + locals: Seq[Ast], + initAsts: Seq[Ast], + conditionAsts: Seq[Ast], + updateAsts: Seq[Ast], + bodyAst: Ast + ): Ast = + forAst(forNode, locals, initAsts, conditionAsts, updateAsts, Seq(bodyAst)) - def forAst( - forNode: NewControlStructure, - locals: Seq[Ast], - initAsts: Seq[Ast], - conditionAsts: Seq[Ast], - updateAsts: Seq[Ast], - bodyAsts: Seq[Ast] - ): Ast = - val lineNumber = forNode.lineNumber - Ast(forNode) - .withChildren(locals) - .withChild(wrapMultipleInBlock(initAsts, lineNumber, forNode.columnNumber)) - .withChild(wrapMultipleInBlock(conditionAsts, lineNumber, forNode.columnNumber)) - .withChild(wrapMultipleInBlock(updateAsts, lineNumber, forNode.columnNumber)) - .withChildren(bodyAsts) - .withConditionEdges(forNode, conditionAsts.flatMap(_.root).toList) + def forAst( + forNode: NewControlStructure, + locals: Seq[Ast], + initAsts: Seq[Ast], + conditionAsts: Seq[Ast], + updateAsts: Seq[Ast], + bodyAsts: Seq[Ast] + ): Ast = + val lineNumber = forNode.lineNumber + Ast(forNode) + .withChildren(locals) + .withChild(wrapMultipleInBlock(initAsts, lineNumber, forNode.columnNumber)) + .withChild(wrapMultipleInBlock(conditionAsts, lineNumber, forNode.columnNumber)) + .withChild(wrapMultipleInBlock(updateAsts, lineNumber, forNode.columnNumber)) + .withChildren(bodyAsts) + .withConditionEdges(forNode, conditionAsts.flatMap(_.root).toList) - /** For the given try body, catch ASTs and finally AST, create a try-catch-finally AST with - * orders set correctly for the ossdataflow engine. - */ - def tryCatchAst( - tryNode: NewControlStructure, - tryBodyAst: Ast, - catchAsts: Seq[Ast], - finallyAst: Option[Ast] - ): Ast = - tryBodyAst.root.collect { case x: ExpressionNew => x }.foreach(_.order = 1) - catchAsts.flatMap(_.root).collect { case x: ExpressionNew => x }.foreach(_.order = 2) - finallyAst.flatMap(_.root).collect { case x: ExpressionNew => x }.foreach(_.order = 3) - Ast(tryNode) - .withChild(tryBodyAst) - .withChildren(catchAsts) - .withChildren(finallyAst.toList) + /** For the given try body, catch ASTs and finally AST, create a try-catch-finally AST with orders + * set correctly for the ossdataflow engine. + */ + def tryCatchAst( + tryNode: NewControlStructure, + tryBodyAst: Ast, + catchAsts: Seq[Ast], + finallyAst: Option[Ast] + ): Ast = + tryBodyAst.root.collect { case x: ExpressionNew => x }.foreach(_.order = 1) + catchAsts.flatMap(_.root).collect { case x: ExpressionNew => x }.foreach(_.order = 2) + finallyAst.flatMap(_.root).collect { case x: ExpressionNew => x }.foreach(_.order = 3) + Ast(tryNode) + .withChild(tryBodyAst) + .withChildren(catchAsts) + .withChildren(finallyAst.toList) - /** For a given block node and statement ASTs, create an AST that represents the block. The main - * purpose of this method is to increase the readability of the code which creates block asts. - */ - def blockAst(blockNode: NewBlock, statements: List[Ast] = List()): Ast = - Ast(blockNode).withChildren(statements) + /** For a given block node and statement ASTs, create an AST that represents the block. The main + * purpose of this method is to increase the readability of the code which creates block asts. + */ + def blockAst(blockNode: NewBlock, statements: List[Ast] = List()): Ast = + Ast(blockNode).withChildren(statements) - /** Create an abstract syntax tree for a call, including CPG-specific edges required for - * arguments and the receiver. - * - * Our call representation is inspired by ECMAScript, that is, in addition to arguments, a call - * has a base and a receiver. For languages other than Javascript, leave `receiver` empty for - * now. - * - * @param callNode - * the node that represents the entire call - * @param arguments - * arguments (without the base argument (instance)) - * @param base - * the value to use as `this` in the method call. - * @param receiver - * the object in which the property lookup is performed - */ - def callAst( - callNode: NewCall, - arguments: Seq[Ast] = List(), - base: Option[Ast] = None, - receiver: Option[Ast] = None - ): Ast = + /** Create an abstract syntax tree for a call, including CPG-specific edges required for arguments + * and the receiver. + * + * Our call representation is inspired by ECMAScript, that is, in addition to arguments, a call + * has a base and a receiver. For languages other than Javascript, leave `receiver` empty for + * now. + * + * @param callNode + * the node that represents the entire call + * @param arguments + * arguments (without the base argument (instance)) + * @param base + * the value to use as `this` in the method call. + * @param receiver + * the object in which the property lookup is performed + */ + def callAst( + callNode: NewCall, + arguments: Seq[Ast] = List(), + base: Option[Ast] = None, + receiver: Option[Ast] = None + ): Ast = - setArgumentIndices(arguments) + setArgumentIndices(arguments) - val baseRoot = base.flatMap(_.root).toList - val bse = base.getOrElse(Ast()) - baseRoot match - case List(x: ExpressionNew) => - x.argumentIndex = 0 - case _ => + val baseRoot = base.flatMap(_.root).toList + val bse = base.getOrElse(Ast()) + baseRoot match + case List(x: ExpressionNew) => + x.argumentIndex = 0 + case _ => - val receiverRoot = if receiver.isEmpty && base.nonEmpty then - baseRoot - else - val r = receiver.flatMap(_.root).toList - r match - case List(x: ExpressionNew) => - x.argumentIndex = -1 - case _ => - r + val receiverRoot = if receiver.isEmpty && base.nonEmpty then + baseRoot + else + val r = receiver.flatMap(_.root).toList + r match + case List(x: ExpressionNew) => + x.argumentIndex = -1 + case _ => + r - val rcvAst = receiver.getOrElse(Ast()) + val rcvAst = receiver.getOrElse(Ast()) - Ast(callNode) - .withChild(rcvAst) - .withChild(bse) - .withChildren(arguments) - .withArgEdges(callNode, baseRoot) - .withArgEdges(callNode, arguments.flatMap(_.root)) - .withReceiverEdges(callNode, receiverRoot) - end callAst + Ast(callNode) + .withChild(rcvAst) + .withChild(bse) + .withChildren(arguments) + .withArgEdges(callNode, baseRoot) + .withArgEdges(callNode, arguments.flatMap(_.root)) + .withReceiverEdges(callNode, receiverRoot) + end callAst - def setArgumentIndices(arguments: Seq[Ast]): Unit = - var currIndex = 1 - arguments.foreach { a => - a.root match - case Some(x: ExpressionNew) => - x.argumentIndex = currIndex - currIndex = currIndex + 1 - case None => // do nothing - case _ => - currIndex = currIndex + 1 - } + def setArgumentIndices(arguments: Seq[Ast]): Unit = + var currIndex = 1 + arguments.foreach { a => + a.root match + case Some(x: ExpressionNew) => + x.argumentIndex = currIndex + currIndex = currIndex + 1 + case None => // do nothing + case _ => + currIndex = currIndex + 1 + } - def withIndex[T, X](nodes: Seq[T])(f: (T, Int) => X): Seq[X] = - nodes.zipWithIndex.map { case (x, i) => - f(x, i + 1) - } + def withIndex[T, X](nodes: Seq[T])(f: (T, Int) => X): Seq[X] = + nodes.zipWithIndex.map { case (x, i) => + f(x, i + 1) + } - def withIndex[T, X](nodes: Array[T])(f: (T, Int) => X): Seq[X] = - nodes.toIndexedSeq.zipWithIndex.map { case (x, i) => - f(x, i + 1) - } + def withIndex[T, X](nodes: Array[T])(f: (T, Int) => X): Seq[X] = + nodes.toIndexedSeq.zipWithIndex.map { case (x, i) => + f(x, i + 1) + } - def withArgumentIndex[T <: ExpressionNew](node: T, argIdxOpt: Option[Int]): T = - argIdxOpt match - case Some(argIdx) => - node.argumentIndex = argIdx - node - case None => node + def withArgumentIndex[T <: ExpressionNew](node: T, argIdxOpt: Option[Int]): T = + argIdxOpt match + case Some(argIdx) => + node.argumentIndex = argIdx + node + case None => node - def withArgumentName[T <: ExpressionNew](node: T, argNameOpt: Option[String]): T = - node.argumentName = argNameOpt - node + def withArgumentName[T <: ExpressionNew](node: T, argNameOpt: Option[String]): T = + node.argumentName = argNameOpt + node - /** Absolute path for the given file name - */ - def absolutePath(filename: String): String = - better.files.File(filename).path.toAbsolutePath.normalize().toString + /** Absolute path for the given file name + */ + def absolutePath(filename: String): String = + better.files.File(filename).path.toAbsolutePath.normalize().toString - def nextClosureName(): String = s"${Defines.ClosurePrefix}${closureKeyPool.next}" + def nextClosureName(): String = s"${Defines.ClosurePrefix}${closureKeyPool.next}" end AstCreatorBase diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/AstNodeBuilder.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/AstNodeBuilder.scala index e19b6331..a1050657 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/AstNodeBuilder.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/AstNodeBuilder.scala @@ -25,324 +25,324 @@ import io.shiftleft.codepropertygraph.generated.nodes.Block.{PropertyDefaults as import org.apache.commons.lang.StringUtils import io.appthreat.x2cpg.utils.NodeBuilders.newMethodReturnNode trait AstNodeBuilder[Node, NodeProcessor]: - this: NodeProcessor => - protected def line(node: Node): Option[Integer] - protected def column(node: Node): Option[Integer] - protected def lineEnd(node: Node): Option[Integer] - protected def columnEnd(element: Node): Option[Integer] + this: NodeProcessor => + protected def line(node: Node): Option[Integer] + protected def column(node: Node): Option[Integer] + protected def lineEnd(node: Node): Option[Integer] + protected def columnEnd(element: Node): Option[Integer] - private val MinCodeLength: Int = 50 - private val DefaultMaxCodeLength: Int = 1000 - // maximum length of code fields in number of characters - private lazy val MaxCodeLength: Int = - sys.env.get("CHEN_MAX_CODE_LENGTH").flatMap(_.toIntOption).getOrElse(DefaultMaxCodeLength) + private val MinCodeLength: Int = 50 + private val DefaultMaxCodeLength: Int = 1000 + // maximum length of code fields in number of characters + private lazy val MaxCodeLength: Int = + sys.env.get("CHEN_MAX_CODE_LENGTH").flatMap(_.toIntOption).getOrElse(DefaultMaxCodeLength) - protected def code(node: Node): String + protected def code(node: Node): String - protected def shortenCode(code: String): String = - StringUtils.abbreviate(code, math.max(MinCodeLength, MaxCodeLength)) + protected def shortenCode(code: String): String = + StringUtils.abbreviate(code, math.max(MinCodeLength, MaxCodeLength)) - protected def offset(node: Node): Option[(Int, Int)] = None + protected def offset(node: Node): Option[(Int, Int)] = None - protected def unknownNode(node: Node, code: String): NewUnknown = - NewUnknown() - .parserTypeName(node.getClass.getSimpleName) - .code(code) - .lineNumber(line(node)) - .columnNumber(column(node)) - - protected def annotationNode( - node: Node, - code: String, - name: String, - fullName: String - ): NewAnnotation = - NewAnnotation() - .code(code) - .name(name) - .fullName(fullName) - .lineNumber(line(node)) - .columnNumber(column(node)) + protected def unknownNode(node: Node, code: String): NewUnknown = + NewUnknown() + .parserTypeName(node.getClass.getSimpleName) + .code(code) + .lineNumber(line(node)) + .columnNumber(column(node)) - protected def methodRefNode( - node: Node, - code: String, - methodFullName: String, - typeFullName: String - ): NewMethodRef = - NewMethodRef() - .code(code) - .methodFullName(methodFullName) - .typeFullName(typeFullName) - .lineNumber(line(node)) - .columnNumber(column(node)) - - protected def memberNode( - node: Node, - name: String, - code: String, - typeFullName: String - ): NewMember = - memberNode(node, name, code, typeFullName, Seq()) + protected def annotationNode( + node: Node, + code: String, + name: String, + fullName: String + ): NewAnnotation = + NewAnnotation() + .code(code) + .name(name) + .fullName(fullName) + .lineNumber(line(node)) + .columnNumber(column(node)) - protected def memberNode( - node: Node, - name: String, - code: String, - typeFullName: String, - dynamicTypeHints: Seq[String] = Seq() - ): NewMember = - NewMember() - .code(code) - .name(name) - .typeFullName(typeFullName) - .dynamicTypeHintFullName(dynamicTypeHints) - .lineNumber(line(node)) - .columnNumber(column(node)) + protected def methodRefNode( + node: Node, + code: String, + methodFullName: String, + typeFullName: String + ): NewMethodRef = + NewMethodRef() + .code(code) + .methodFullName(methodFullName) + .typeFullName(typeFullName) + .lineNumber(line(node)) + .columnNumber(column(node)) - protected def newImportNode( - code: String, - importedEntity: String, - importedAs: String, - include: Node - ): NewImport = - NewImport() - .code(code) - .importedEntity(importedEntity) - .importedAs(importedAs) - .lineNumber(line(include)) - .columnNumber(column(include)) + protected def memberNode( + node: Node, + name: String, + code: String, + typeFullName: String + ): NewMember = + memberNode(node, name, code, typeFullName, Seq()) - protected def literalNode( - node: Node, - code: String, - typeFullName: String, - dynamicTypeHints: Seq[String] = Seq() - ): NewLiteral = - NewLiteral() - .code(code) - .typeFullName(typeFullName) - .dynamicTypeHintFullName(dynamicTypeHints) - .lineNumber(line(node)) - .columnNumber(column(node)) + protected def memberNode( + node: Node, + name: String, + code: String, + typeFullName: String, + dynamicTypeHints: Seq[String] = Seq() + ): NewMember = + NewMember() + .code(code) + .name(name) + .typeFullName(typeFullName) + .dynamicTypeHintFullName(dynamicTypeHints) + .lineNumber(line(node)) + .columnNumber(column(node)) - protected def typeRefNode(node: Node, code: String, typeFullName: String): NewTypeRef = - NewTypeRef() - .code(code) - .typeFullName(typeFullName) - .lineNumber(line(node)) - .columnNumber(column(node)) + protected def newImportNode( + code: String, + importedEntity: String, + importedAs: String, + include: Node + ): NewImport = + NewImport() + .code(code) + .importedEntity(importedEntity) + .importedAs(importedAs) + .lineNumber(line(include)) + .columnNumber(column(include)) - def typeDeclNode( - node: Node, - name: String, - fullName: String, - fileName: String, - inheritsFrom: Seq[String], - alias: Option[String] - ): NewTypeDecl = - typeDeclNode(node, name, fullName, fileName, name, "", "", inheritsFrom, alias) + protected def literalNode( + node: Node, + code: String, + typeFullName: String, + dynamicTypeHints: Seq[String] = Seq() + ): NewLiteral = + NewLiteral() + .code(code) + .typeFullName(typeFullName) + .dynamicTypeHintFullName(dynamicTypeHints) + .lineNumber(line(node)) + .columnNumber(column(node)) - protected def typeDeclNode( - node: Node, - name: String, - fullName: String, - filename: String, - code: String, - astParentType: String = "", - astParentFullName: String = "", - inherits: Seq[String] = Seq.empty, - alias: Option[String] = None - ): NewTypeDecl = - NewTypeDecl() - .name(name) - .fullName(fullName) - .code(code) - .isExternal(false) - .filename(filename) - .astParentType(astParentType) - .astParentFullName(astParentFullName) - .inheritsFromTypeFullName(inherits) - .aliasTypeFullName(alias) - .lineNumber(line(node)) - .columnNumber(column(node)) + protected def typeRefNode(node: Node, code: String, typeFullName: String): NewTypeRef = + NewTypeRef() + .code(code) + .typeFullName(typeFullName) + .lineNumber(line(node)) + .columnNumber(column(node)) - protected def parameterInNode( - node: Node, - name: String, - code: String, - index: Int, - isVariadic: Boolean, - evaluationStrategy: String, - typeFullName: String - ): NewMethodParameterIn = - parameterInNode(node, name, code, index, isVariadic, evaluationStrategy, Some(typeFullName)) + def typeDeclNode( + node: Node, + name: String, + fullName: String, + fileName: String, + inheritsFrom: Seq[String], + alias: Option[String] + ): NewTypeDecl = + typeDeclNode(node, name, fullName, fileName, name, "", "", inheritsFrom, alias) - protected def parameterInNode( - node: Node, - name: String, - code: String, - index: Int, - isVariadic: Boolean, - evaluationStrategy: String, - typeFullName: Option[String] = None - ): NewMethodParameterIn = - NewMethodParameterIn() - .name(name) - .code(code) - .index(index) - .order(index) - .isVariadic(isVariadic) - .evaluationStrategy(evaluationStrategy) - .lineNumber(line(node)) - .columnNumber(column(node)) - .typeFullName(typeFullName.getOrElse("ANY")) + protected def typeDeclNode( + node: Node, + name: String, + fullName: String, + filename: String, + code: String, + astParentType: String = "", + astParentFullName: String = "", + inherits: Seq[String] = Seq.empty, + alias: Option[String] = None + ): NewTypeDecl = + NewTypeDecl() + .name(name) + .fullName(fullName) + .code(code) + .isExternal(false) + .filename(filename) + .astParentType(astParentType) + .astParentFullName(astParentFullName) + .inheritsFromTypeFullName(inherits) + .aliasTypeFullName(alias) + .lineNumber(line(node)) + .columnNumber(column(node)) - def callNode( - node: Node, - code: String, - name: String, - methodFullName: String, - dispatchType: String - ): NewCall = - callNode(node, code, name, methodFullName, dispatchType, None, None) + protected def parameterInNode( + node: Node, + name: String, + code: String, + index: Int, + isVariadic: Boolean, + evaluationStrategy: String, + typeFullName: String + ): NewMethodParameterIn = + parameterInNode(node, name, code, index, isVariadic, evaluationStrategy, Some(typeFullName)) - def callNode( - node: Node, - code: String, - name: String, - methodFullName: String, - dispatchType: String, - signature: Option[String], - typeFullName: Option[String] - ): NewCall = - val out = - NewCall() - .code(code) - .name(name) - .methodFullName(methodFullName) - .dispatchType(dispatchType) - .lineNumber(line(node)) - .columnNumber(column(node)) - signature.foreach { s => out.signature(s) } - typeFullName.foreach { t => out.typeFullName(t) } - out - end callNode + protected def parameterInNode( + node: Node, + name: String, + code: String, + index: Int, + isVariadic: Boolean, + evaluationStrategy: String, + typeFullName: Option[String] = None + ): NewMethodParameterIn = + NewMethodParameterIn() + .name(name) + .code(code) + .index(index) + .order(index) + .isVariadic(isVariadic) + .evaluationStrategy(evaluationStrategy) + .lineNumber(line(node)) + .columnNumber(column(node)) + .typeFullName(typeFullName.getOrElse("ANY")) - protected def returnNode(node: Node, code: String): NewReturn = - NewReturn() - .code(code) - .lineNumber(line(node)) - .columnNumber(column(node)) + def callNode( + node: Node, + code: String, + name: String, + methodFullName: String, + dispatchType: String + ): NewCall = + callNode(node, code, name, methodFullName, dispatchType, None, None) - protected def controlStructureNode( - node: Node, - controlStructureType: String, - code: String - ): NewControlStructure = - NewControlStructure() - .parserTypeName(node.getClass.getSimpleName) - .controlStructureType(controlStructureType) + def callNode( + node: Node, + code: String, + name: String, + methodFullName: String, + dispatchType: String, + signature: Option[String], + typeFullName: Option[String] + ): NewCall = + val out = + NewCall() .code(code) + .name(name) + .methodFullName(methodFullName) + .dispatchType(dispatchType) .lineNumber(line(node)) .columnNumber(column(node)) + signature.foreach { s => out.signature(s) } + typeFullName.foreach { t => out.typeFullName(t) } + out + end callNode - protected def blockNode(node: Node): NewBlock = - blockNode(node, BlockDefaults.Code, BlockDefaults.TypeFullName) + protected def returnNode(node: Node, code: String): NewReturn = + NewReturn() + .code(code) + .lineNumber(line(node)) + .columnNumber(column(node)) - protected def blockNode(node: Node, code: String, typeFullName: String): NewBlock = - NewBlock() - .code(code) - .typeFullName(typeFullName) - .lineNumber(line(node)) - .columnNumber(column(node)) + protected def controlStructureNode( + node: Node, + controlStructureType: String, + code: String + ): NewControlStructure = + NewControlStructure() + .parserTypeName(node.getClass.getSimpleName) + .controlStructureType(controlStructureType) + .code(code) + .lineNumber(line(node)) + .columnNumber(column(node)) - protected def fieldIdentifierNode(node: Node, name: String, code: String): NewFieldIdentifier = - NewFieldIdentifier() - .canonicalName(name) - .code(code) - .lineNumber(line(node)) - .columnNumber(column(node)) + protected def blockNode(node: Node): NewBlock = + blockNode(node, BlockDefaults.Code, BlockDefaults.TypeFullName) - protected def localNode( - node: Node, - name: String, - code: String, - typeFullName: String, - closureBindingId: Option[String] = None - ): NewLocal = - NewLocal() - .name(name) - .code(code) - .typeFullName(typeFullName) - .closureBindingId(closureBindingId) - .lineNumber(line(node)) - .columnNumber(column(node)) + protected def blockNode(node: Node, code: String, typeFullName: String): NewBlock = + NewBlock() + .code(code) + .typeFullName(typeFullName) + .lineNumber(line(node)) + .columnNumber(column(node)) - protected def identifierNode( - node: Node, - name: String, - code: String, - typeFullName: String, - dynamicTypeHints: Seq[String] = Seq() - ): NewIdentifier = - NewIdentifier() - .name(name) - .typeFullName(typeFullName) - .code(code) - .dynamicTypeHintFullName(dynamicTypeHints) - .lineNumber(line(node)) - .columnNumber(column(node)) + protected def fieldIdentifierNode(node: Node, name: String, code: String): NewFieldIdentifier = + NewFieldIdentifier() + .canonicalName(name) + .code(code) + .lineNumber(line(node)) + .columnNumber(column(node)) - def methodNode( - node: Node, - name: String, - fullName: String, - signature: String, - fileName: String - ): NewMethod = - methodNode(node, name, name, fullName, Some(signature), fileName) + protected def localNode( + node: Node, + name: String, + code: String, + typeFullName: String, + closureBindingId: Option[String] = None + ): NewLocal = + NewLocal() + .name(name) + .code(code) + .typeFullName(typeFullName) + .closureBindingId(closureBindingId) + .lineNumber(line(node)) + .columnNumber(column(node)) - protected def methodNode( - node: Node, - name: String, - code: String, - fullName: String, - signature: Option[String], - fileName: String, - astParentType: Option[String] = None, - astParentFullName: Option[String] = None - ): NewMethod = - val node_ = - NewMethod() - .name(StringUtils.normalizeSpace(name)) - .code(code) - .fullName(StringUtils.normalizeSpace(fullName)) - .filename(fileName) - .astParentType(astParentType.getOrElse("")) - .astParentFullName(astParentFullName.getOrElse("")) - .isExternal(false) - .lineNumber(line(node)) - .columnNumber(column(node)) - .lineNumberEnd(lineEnd(node)) - .columnNumberEnd(columnEnd(node)) - signature.foreach { s => node_.signature(StringUtils.normalizeSpace(s)) } - node_ - end methodNode + protected def identifierNode( + node: Node, + name: String, + code: String, + typeFullName: String, + dynamicTypeHints: Seq[String] = Seq() + ): NewIdentifier = + NewIdentifier() + .name(name) + .typeFullName(typeFullName) + .code(code) + .dynamicTypeHintFullName(dynamicTypeHints) + .lineNumber(line(node)) + .columnNumber(column(node)) - protected def methodReturnNode(node: Node, typeFullName: String): NewMethodReturn = - newMethodReturnNode(typeFullName, None, line(node), column(node)) + def methodNode( + node: Node, + name: String, + fullName: String, + signature: String, + fileName: String + ): NewMethod = + methodNode(node, name, name, fullName, Some(signature), fileName) - protected def jumpTargetNode( - node: Node, - name: String, - code: String, - parserTypeName: Option[String] = None - ): NewJumpTarget = - NewJumpTarget() - .parserTypeName(parserTypeName.getOrElse(node.getClass.getSimpleName)) - .name(name) + protected def methodNode( + node: Node, + name: String, + code: String, + fullName: String, + signature: Option[String], + fileName: String, + astParentType: Option[String] = None, + astParentFullName: Option[String] = None + ): NewMethod = + val node_ = + NewMethod() + .name(StringUtils.normalizeSpace(name)) .code(code) + .fullName(StringUtils.normalizeSpace(fullName)) + .filename(fileName) + .astParentType(astParentType.getOrElse("")) + .astParentFullName(astParentFullName.getOrElse("")) + .isExternal(false) .lineNumber(line(node)) .columnNumber(column(node)) + .lineNumberEnd(lineEnd(node)) + .columnNumberEnd(columnEnd(node)) + signature.foreach { s => node_.signature(StringUtils.normalizeSpace(s)) } + node_ + end methodNode + + protected def methodReturnNode(node: Node, typeFullName: String): NewMethodReturn = + newMethodReturnNode(typeFullName, None, line(node), column(node)) + + protected def jumpTargetNode( + node: Node, + name: String, + code: String, + parserTypeName: Option[String] = None + ): NewJumpTarget = + NewJumpTarget() + .parserTypeName(parserTypeName.getOrElse(node.getClass.getSimpleName)) + .name(name) + .code(code) + .lineNumber(line(node)) + .columnNumber(column(node)) end AstNodeBuilder diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/Defines.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/Defines.scala index 3bfd95ad..6aaca645 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/Defines.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/Defines.scala @@ -1,32 +1,34 @@ package io.appthreat.x2cpg object Defines: - // The following two defines should be used for type and method full names to - // indicate unresolved static type information. Using them enables - // the closed source backend to apply policies in a less strict fashion. - // The most notable case is the METHOD_FULL_NAME property of a CALL node. - // As example consider a call to a method `foo(someArg)` which cannot be - // resolved. The METHOD_FULL_NAME should be given as - // ".foo:(1)". If the namespace is known - // the METHOD_FULL_NAME should be given as - // "some.namespace.foo:(1)". Thereby the number in parenthesis - // is the number of call arguments. - // Note that this schema and thus the defines only makes sense for statically - // typed languages with a package/namespace structure like Java, CSharp, etc.. - val UnresolvedNamespace = "" - val UnresolvedSignature = "" + // The following two defines should be used for type and method full names to + // indicate unresolved static type information. Using them enables + // the closed source backend to apply policies in a less strict fashion. + // The most notable case is the METHOD_FULL_NAME property of a CALL node. + // As example consider a call to a method `foo(someArg)` which cannot be + // resolved. The METHOD_FULL_NAME should be given as + // ".foo:(1)". If the namespace is known + // the METHOD_FULL_NAME should be given as + // "some.namespace.foo:(1)". Thereby the number in parenthesis + // is the number of call arguments. + // Note that this schema and thus the defines only makes sense for statically + // typed languages with a package/namespace structure like Java, CSharp, etc.. + val Any = "ANY" + val UnresolvedNamespace = "" + val UnresolvedSignature = "" - // Name of the synthetic, static method that contains the initialization of member variables. - val StaticInitMethodName = "" + // Name of the synthetic, static method that contains the initialization of member variables. + val StaticInitMethodName = "" - // Name of the constructor. - val ConstructorMethodName = "" + // Name of the constructor. + val ConstructorMethodName = "" - // In some languages like Javascript dynamic calls do not provide any statically known - // method/function interface information. In those cases please use this value. - val DynamicCallUnknownFullName = "" + // In some languages like Javascript dynamic calls do not provide any statically known + // method/function interface information. In those cases please use this value. + val DynamicCallUnknownFullName = "" - val LeftAngularBracket = "<" - val Unknown = "" - val ClosurePrefix = "" + val LeftAngularBracket = "<" + val Unknown = "" + // Anonymous functions, lambdas, and closures, follow the naming scheme of $LambdaPrefix$int + val ClosurePrefix = "" end Defines diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/Imports.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/Imports.scala index de16e86c..713a6165 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/Imports.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/Imports.scala @@ -6,15 +6,15 @@ import overflowdb.BatchedUpdate.DiffGraphBuilder object Imports: - def createImportNodeAndLink( - importedEntity: String, - importedAs: String, - call: Option[CallBase], - diffGraph: DiffGraphBuilder - ): NewImport = - val importNode = NewImport() - .importedAs(importedAs) - .importedEntity(importedEntity) - diffGraph.addNode(importNode) - call.foreach { c => diffGraph.addEdge(c, importNode, EdgeTypes.IS_CALL_FOR_IMPORT) } - importNode + def createImportNodeAndLink( + importedEntity: String, + importedAs: String, + call: Option[CallBase], + diffGraph: DiffGraphBuilder + ): NewImport = + val importNode = NewImport() + .importedAs(importedAs) + .importedEntity(importedEntity) + diffGraph.addNode(importNode) + call.foreach { c => diffGraph.addEdge(c, importNode, EdgeTypes.IS_CALL_FOR_IMPORT) } + importNode diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/SourceFiles.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/SourceFiles.scala index ae8eb912..3785b60b 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/SourceFiles.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/SourceFiles.scala @@ -9,141 +9,141 @@ import java.nio.file.Paths object SourceFiles: - private val logger = LoggerFactory.getLogger(getClass) - - private def isIgnoredByFileList(filePath: String, config: X2CpgConfig[?]): Boolean = - val isInIgnoredFiles = config.ignoredFiles.exists { - case ignorePath if File(ignorePath).isDirectory => filePath.startsWith(ignorePath) - case ignorePath => filePath == ignorePath - } - if isInIgnoredFiles then - logger.debug(s"'$filePath' ignored (--exclude)") - true + private val logger = LoggerFactory.getLogger(getClass) + + private def isIgnoredByFileList(filePath: String, config: X2CpgConfig[?]): Boolean = + val isInIgnoredFiles = config.ignoredFiles.exists { + case ignorePath if File(ignorePath).isDirectory => filePath.startsWith(ignorePath) + case ignorePath => filePath == ignorePath + } + if isInIgnoredFiles then + logger.debug(s"'$filePath' ignored (--exclude)") + true + else + false + + private def isIgnoredByDefault(filePath: String, config: X2CpgConfig[?]): Boolean = + val relPath = toRelativePath(filePath, config.inputPath) + if config.defaultIgnoredFilesRegex.exists(_.matches(relPath)) || File( + filePath + ).isSymbolicLink + then + logger.debug(s"'$relPath' ignored by default") + true + else + false + + private def isIgnoredByRegex(filePath: String, config: X2CpgConfig[?]): Boolean = + val relPath = toRelativePath(filePath, config.inputPath) + val isInIgnoredFilesRegex = config.ignoredFilesRegex.matches(relPath) + if isInIgnoredFilesRegex then + logger.debug(s"'$relPath' ignored (--exclude-regex)") + true + else + false + + private def filterFiles(files: List[String], config: X2CpgConfig[?]): List[String] = + files.filter { + case filePath if isIgnoredByDefault(filePath, config) => false + case filePath if isIgnoredByFileList(filePath, config) => false + case filePath if isIgnoredByRegex(filePath, config) => false + case _ => true + } + + /** For a given input path, determine all source files by inspecting filename extensions. + */ + def determine(inputPath: String, sourceFileExtensions: Set[String]): List[String] = + determine(Set(inputPath), sourceFileExtensions) + + /** For a given input path, determine all source files by inspecting filename extensions and + * filter the result according to the given config (by its ignoredFilesRegex and ignoredFiles). + */ + def determine( + inputPath: String, + sourceFileExtensions: Set[String], + config: X2CpgConfig[?] + ): List[String] = + determine(Set(inputPath), sourceFileExtensions, config) + + /** For given input paths, determine all source files by inspecting filename extensions and filter + * the result according to the given config (by its ignoredFilesRegex and ignoredFiles). + */ + def determine( + inputPaths: Set[String], + sourceFileExtensions: Set[String], + config: X2CpgConfig[?] + ): List[String] = + filterFiles(determine(inputPaths, sourceFileExtensions), config) + + /** For a given array of input paths, determine all source files by inspecting filename + * extensions. + */ + def determine(inputPaths: Set[String], sourceFileExtensions: Set[String]): List[String] = + def hasSourceFileExtension(file: File): Boolean = + file.extension.exists(sourceFileExtensions.contains) + + val inputFiles = inputPaths.map(File(_)) + assertAllExist(inputFiles) + + val (dirs, files) = inputFiles.partition(_.isDirectory) + + val matchingFiles = files.filter(hasSourceFileExtension).map(_.toString) + val matchingFilesFromDirs = dirs + .flatMap(_.listRecursively(VisitOptions.default)) + .filter(hasSourceFileExtension) + .map(_.pathAsString) + + (matchingFiles ++ matchingFilesFromDirs).toList.sorted + + /** Attempting to analyse source paths that do not exist is a hard error. Terminate execution + * early to avoid unexpected and hard-to-debug issues in the results. + */ + private def assertAllExist(files: Set[File]): Unit = + val (existant, nonExistant) = files.partition(_.isReadable) + val nonReadable = existant.filterNot(_.isReadable) + + if nonExistant.nonEmpty || nonReadable.nonEmpty then + logErrorWithPaths("Source input paths do not exist", nonExistant.map(_.canonicalPath)) + + logErrorWithPaths( + "Source input paths exist, but are not readable", + nonReadable.map(_.canonicalPath) + ) + + throw FileNotFoundException("Invalid source paths provided") + + private def logErrorWithPaths(message: String, paths: Iterable[String]): Unit = + val pathsArray = paths.toArray.sorted + + pathsArray.lengthCompare(1) match + case cmp if cmp < 0 => // pathsArray is empty, so don't log anything + case cmp if cmp == 0 => logger.debug(s"$message: ${paths.head}") + + case cmp => + val errorMessage = (message +: pathsArray.map(path => s"- $path")).mkString("\n") + logger.debug(errorMessage) + + /** Constructs an absolute path against rootPath. If the given path is already absolute this path + * is returned unaltered. Otherwise, "rootPath / path" is returned. + */ + def toAbsolutePath(path: String, rootPath: String): String = + val absolutePath = Paths.get(path) match + case p if p.isAbsolute => p + case _ if rootPath.endsWith(path) => Paths.get(rootPath) + case p => Paths.get(rootPath, p.toString) + absolutePath.normalize().toString + + /** Constructs a relative path against rootPath. If the given path is not inside rootPath, path is + * returned unaltered. Otherwise, the path relative to rootPath is returned. + */ + def toRelativePath(path: String, rootPath: String): String = + if path.startsWith(rootPath) then + val absolutePath = Paths.get(path).toAbsolutePath + val projectPath = Paths.get(rootPath).toAbsolutePath + if absolutePath.compareTo(projectPath) == 0 then + absolutePath.getFileName.toString else - false - - private def isIgnoredByDefault(filePath: String, config: X2CpgConfig[?]): Boolean = - val relPath = toRelativePath(filePath, config.inputPath) - if config.defaultIgnoredFilesRegex.exists(_.matches(relPath)) || File( - filePath - ).isSymbolicLink - then - logger.debug(s"'$relPath' ignored by default") - true - else - false - - private def isIgnoredByRegex(filePath: String, config: X2CpgConfig[?]): Boolean = - val relPath = toRelativePath(filePath, config.inputPath) - val isInIgnoredFilesRegex = config.ignoredFilesRegex.matches(relPath) - if isInIgnoredFilesRegex then - logger.debug(s"'$relPath' ignored (--exclude-regex)") - true - else - false - - private def filterFiles(files: List[String], config: X2CpgConfig[?]): List[String] = - files.filter { - case filePath if isIgnoredByDefault(filePath, config) => false - case filePath if isIgnoredByFileList(filePath, config) => false - case filePath if isIgnoredByRegex(filePath, config) => false - case _ => true - } - - /** For a given input path, determine all source files by inspecting filename extensions. - */ - def determine(inputPath: String, sourceFileExtensions: Set[String]): List[String] = - determine(Set(inputPath), sourceFileExtensions) - - /** For a given input path, determine all source files by inspecting filename extensions and - * filter the result according to the given config (by its ignoredFilesRegex and ignoredFiles). - */ - def determine( - inputPath: String, - sourceFileExtensions: Set[String], - config: X2CpgConfig[?] - ): List[String] = - determine(Set(inputPath), sourceFileExtensions, config) - - /** For given input paths, determine all source files by inspecting filename extensions and - * filter the result according to the given config (by its ignoredFilesRegex and ignoredFiles). - */ - def determine( - inputPaths: Set[String], - sourceFileExtensions: Set[String], - config: X2CpgConfig[?] - ): List[String] = - filterFiles(determine(inputPaths, sourceFileExtensions), config) - - /** For a given array of input paths, determine all source files by inspecting filename - * extensions. - */ - def determine(inputPaths: Set[String], sourceFileExtensions: Set[String]): List[String] = - def hasSourceFileExtension(file: File): Boolean = - file.extension.exists(sourceFileExtensions.contains) - - val inputFiles = inputPaths.map(File(_)) - assertAllExist(inputFiles) - - val (dirs, files) = inputFiles.partition(_.isDirectory) - - val matchingFiles = files.filter(hasSourceFileExtension).map(_.toString) - val matchingFilesFromDirs = dirs - .flatMap(_.listRecursively(VisitOptions.default)) - .filter(hasSourceFileExtension) - .map(_.pathAsString) - - (matchingFiles ++ matchingFilesFromDirs).toList.sorted - - /** Attempting to analyse source paths that do not exist is a hard error. Terminate execution - * early to avoid unexpected and hard-to-debug issues in the results. - */ - private def assertAllExist(files: Set[File]): Unit = - val (existant, nonExistant) = files.partition(_.isReadable) - val nonReadable = existant.filterNot(_.isReadable) - - if nonExistant.nonEmpty || nonReadable.nonEmpty then - logErrorWithPaths("Source input paths do not exist", nonExistant.map(_.canonicalPath)) - - logErrorWithPaths( - "Source input paths exist, but are not readable", - nonReadable.map(_.canonicalPath) - ) - - throw FileNotFoundException("Invalid source paths provided") - - private def logErrorWithPaths(message: String, paths: Iterable[String]): Unit = - val pathsArray = paths.toArray.sorted - - pathsArray.lengthCompare(1) match - case cmp if cmp < 0 => // pathsArray is empty, so don't log anything - case cmp if cmp == 0 => logger.debug(s"$message: ${paths.head}") - - case cmp => - val errorMessage = (message +: pathsArray.map(path => s"- $path")).mkString("\n") - logger.debug(errorMessage) - - /** Constructs an absolute path against rootPath. If the given path is already absolute this - * path is returned unaltered. Otherwise, "rootPath / path" is returned. - */ - def toAbsolutePath(path: String, rootPath: String): String = - val absolutePath = Paths.get(path) match - case p if p.isAbsolute => p - case _ if rootPath.endsWith(path) => Paths.get(rootPath) - case p => Paths.get(rootPath, p.toString) - absolutePath.normalize().toString - - /** Constructs a relative path against rootPath. If the given path is not inside rootPath, path - * is returned unaltered. Otherwise, the path relative to rootPath is returned. - */ - def toRelativePath(path: String, rootPath: String): String = - if path.startsWith(rootPath) then - val absolutePath = Paths.get(path).toAbsolutePath - val projectPath = Paths.get(rootPath).toAbsolutePath - if absolutePath.compareTo(projectPath) == 0 then - absolutePath.getFileName.toString - else - projectPath.relativize(absolutePath).toString - else - path + projectPath.relativize(absolutePath).toString + else + path end SourceFiles diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/X2Cpg.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/X2Cpg.scala index 5c6dc86b..8fcb6d46 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/X2Cpg.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/X2Cpg.scala @@ -16,54 +16,54 @@ import scala.util.matching.Regex import scala.util.{Failure, Success, Try} object X2CpgConfig: - def defaultOutputPath: String = "cpg.bin" + def defaultOutputPath: String = "cpg.bin" trait X2CpgConfig[R <: X2CpgConfig[R]]: - var inputPath: String = "" - var outputPath: String = X2CpgConfig.defaultOutputPath - - def withInputPath(inputPath: String): R = - this.inputPath = Paths.get(inputPath).toAbsolutePath.normalize().toString - this.asInstanceOf[R] - - def withOutputPath(x: String): R = - this.outputPath = x - this.asInstanceOf[R] - - var defaultIgnoredFilesRegex: Seq[Regex] = Seq.empty - var ignoredFilesRegex: Regex = "".r - var ignoredFiles: Seq[String] = Seq.empty - - def withDefaultIgnoredFilesRegex(x: Seq[Regex]): R = - this.defaultIgnoredFilesRegex = x - this.asInstanceOf[R] - - def withIgnoredFilesRegex(x: String): R = - this.ignoredFilesRegex = x.r - this.asInstanceOf[R] - - def withIgnoredFiles(x: Seq[String]): R = - this.ignoredFiles = x.map(createPathForIgnore) - this.asInstanceOf[R] - - def createPathForIgnore(ignore: String): String = - val path = Paths.get(ignore) - if path.isAbsolute then path.toString - else Paths.get(inputPath, ignore).toAbsolutePath.normalize().toString - - var schemaValidation: ValidationMode = ValidationMode.Disabled - - def withSchemaValidation(value: ValidationMode): R = - this.schemaValidation = value - this.asInstanceOf[R] - - def withInheritedFields(config: R): R = - this.inputPath = config.inputPath - this.outputPath = config.outputPath - this.defaultIgnoredFilesRegex = config.defaultIgnoredFilesRegex - this.ignoredFilesRegex = config.ignoredFilesRegex - this.ignoredFiles = config.ignoredFiles - this.asInstanceOf[R] + var inputPath: String = "" + var outputPath: String = X2CpgConfig.defaultOutputPath + + def withInputPath(inputPath: String): R = + this.inputPath = Paths.get(inputPath).toAbsolutePath.normalize().toString + this.asInstanceOf[R] + + def withOutputPath(x: String): R = + this.outputPath = x + this.asInstanceOf[R] + + var defaultIgnoredFilesRegex: Seq[Regex] = Seq.empty + var ignoredFilesRegex: Regex = "".r + var ignoredFiles: Seq[String] = Seq.empty + + def withDefaultIgnoredFilesRegex(x: Seq[Regex]): R = + this.defaultIgnoredFilesRegex = x + this.asInstanceOf[R] + + def withIgnoredFilesRegex(x: String): R = + this.ignoredFilesRegex = x.r + this.asInstanceOf[R] + + def withIgnoredFiles(x: Seq[String]): R = + this.ignoredFiles = x.map(createPathForIgnore) + this.asInstanceOf[R] + + def createPathForIgnore(ignore: String): String = + val path = Paths.get(ignore) + if path.isAbsolute then path.toString + else Paths.get(inputPath, ignore).toAbsolutePath.normalize().toString + + var schemaValidation: ValidationMode = ValidationMode.Disabled + + def withSchemaValidation(value: ValidationMode): R = + this.schemaValidation = value + this.asInstanceOf[R] + + def withInheritedFields(config: R): R = + this.inputPath = config.inputPath + this.outputPath = config.outputPath + this.defaultIgnoredFilesRegex = config.defaultIgnoredFilesRegex + this.ignoredFilesRegex = config.ignoredFilesRegex + this.ignoredFiles = config.ignoredFiles + this.asInstanceOf[R] end X2CpgConfig /** Base class for `Main` classes of CPG frontends. @@ -84,225 +84,225 @@ abstract class X2CpgMain[T <: X2CpgConfig[T], X <: X2CpgFrontend[?]]( implicit defaultConfig: T ): - /** method that evaluates frontend with configuration - */ - def run(config: T, frontend: X): Unit - - def main(args: Array[String]): Unit = - X2Cpg.parseCommandLine(args, cmdLineParser, defaultConfig) match - case Some(config) => - try - run(config, frontend) - catch - case ex: Throwable => - println(ex.getMessage) - ex.printStackTrace() - System.exit(1) - case None => - println("Error parsing the command line") - System.exit(1) + /** method that evaluates frontend with configuration + */ + def run(config: T, frontend: X): Unit + + def main(args: Array[String]): Unit = + X2Cpg.parseCommandLine(args, cmdLineParser, defaultConfig) match + case Some(config) => + try + run(config, frontend) + catch + case ex: Throwable => + println(ex.getMessage) + ex.printStackTrace() + System.exit(1) + case None => + println("Error parsing the command line") + System.exit(1) end X2CpgMain /** Trait that represents a CPG generator, where T is the frontend configuration class. */ trait X2CpgFrontend[T <: X2CpgConfig[?]]: - /** Create a CPG according to given configuration. Returns CPG wrapped in a `Try`, making it - * possible to detect and inspect exceptions in CPG generation. To be provided by the frontend. - */ - def createCpg(config: T): Try[Cpg] - - /** Create CPG according to given configuration, printing errors to the console if they occur. - * The CPG is closed and not returned. - */ - def run(config: T): Unit = - withErrorsToConsole(config) { _ => - createCpg(config) match - case Success(cpg) => - cpg.close() - Success(cpg) - case Failure(exception) => - Failure(exception) - } - - /** Create a CPG with default overlays according to given configuration - */ - def createCpgWithOverlays(config: T): Try[Cpg] = - val maybeCpg = createCpg(config) - maybeCpg.map { cpg => - applyDefaultOverlays(cpg) - cpg - } - - /** Create a CPG for code at `inputPath` and apply default overlays. - */ - def createCpgWithOverlays(inputName: String)(implicit defaultConfig: T): Try[Cpg] = - val maybeCpg = createCpg(inputName) - maybeCpg.map { cpg => - applyDefaultOverlays(cpg) - cpg - } - - /** Create a CPG for code at `inputName` (a single location) with default frontend - * configuration. If `outputName` exists, it is the file name of the resulting CPG. Otherwise, - * the CPG is held in memory. - */ - def createCpg(inputName: String, outputName: Option[String])(implicit - defaultConfig: T - ): Try[Cpg] = - val defaultWithInputPath = defaultConfig.withInputPath(inputName).asInstanceOf[T] - val config = if !outputName.contains(X2CpgConfig.defaultOutputPath) then - if outputName.isEmpty then - defaultWithInputPath.withOutputPath("").asInstanceOf[T] - else - defaultWithInputPath.withOutputPath(outputName.get).asInstanceOf[T] - else - defaultWithInputPath - createCpg(config) - - /** Create a CPG in memory for file at `inputName` with default configuration. - */ - def createCpg(inputName: String)(implicit defaultConfig: T): Try[Cpg] = - createCpg(inputName, None)(defaultConfig) + /** Create a CPG according to given configuration. Returns CPG wrapped in a `Try`, making it + * possible to detect and inspect exceptions in CPG generation. To be provided by the frontend. + */ + def createCpg(config: T): Try[Cpg] + + /** Create CPG according to given configuration, printing errors to the console if they occur. The + * CPG is closed and not returned. + */ + def run(config: T): Unit = + withErrorsToConsole(config) { _ => + createCpg(config) match + case Success(cpg) => + cpg.close() + Success(cpg) + case Failure(exception) => + Failure(exception) + } + + /** Create a CPG with default overlays according to given configuration + */ + def createCpgWithOverlays(config: T): Try[Cpg] = + val maybeCpg = createCpg(config) + maybeCpg.map { cpg => + applyDefaultOverlays(cpg) + cpg + } + + /** Create a CPG for code at `inputPath` and apply default overlays. + */ + def createCpgWithOverlays(inputName: String)(implicit defaultConfig: T): Try[Cpg] = + val maybeCpg = createCpg(inputName) + maybeCpg.map { cpg => + applyDefaultOverlays(cpg) + cpg + } + + /** Create a CPG for code at `inputName` (a single location) with default frontend configuration. + * If `outputName` exists, it is the file name of the resulting CPG. Otherwise, the CPG is held + * in memory. + */ + def createCpg(inputName: String, outputName: Option[String])(implicit + defaultConfig: T + ): Try[Cpg] = + val defaultWithInputPath = defaultConfig.withInputPath(inputName).asInstanceOf[T] + val config = if !outputName.contains(X2CpgConfig.defaultOutputPath) then + if outputName.isEmpty then + defaultWithInputPath.withOutputPath("").asInstanceOf[T] + else + defaultWithInputPath.withOutputPath(outputName.get).asInstanceOf[T] + else + defaultWithInputPath + createCpg(config) + + /** Create a CPG in memory for file at `inputName` with default configuration. + */ + def createCpg(inputName: String)(implicit defaultConfig: T): Try[Cpg] = + createCpg(inputName, None)(defaultConfig) end X2CpgFrontend object X2Cpg: - private val logger = LoggerFactory.getLogger(X2Cpg.getClass) - - /** Parse commands line arguments in `args` using an X2Cpg command line parser, extended with - * the frontend specific options in `frontendSpecific` with the initial configuration set to - * `initialConf`. On success, the configuration is returned wrapped into an Option. On failure, - * error messages are printed and, `None` is returned. - */ - def parseCommandLine[R <: X2CpgConfig[R]]( - args: Array[String], - frontendSpecific: OParser[?, R], - initialConf: R - ): Option[R] = - val parser = commandLineParser(frontendSpecific) - OParser.parse(parser, args, initialConf) - - /** Create a command line parser that can be extended to add options specific for the frontend. - */ - private def commandLineParser[R <: X2CpgConfig[R]](frontendSpecific: OParser[?, R]) - : OParser[?, R] = - val builder = OParser.builder[R] - import builder.* - OParser.sequence( - arg[String]("input-dir") - .text("source directory") - .action((x, c) => c.withInputPath(x)), - opt[String]("output") - .abbr("o") - .text("output filename") - .action { (x, c) => - c.withOutputPath(x) - }, - opt[Seq[String]]("exclude") - .valueName(",,...") - .action { (x, c) => - c.ignoredFiles = c.ignoredFiles ++ x.map(c.createPathForIgnore) - c - } - .text( - "files or folders to exclude during CPG generation (paths relative to or absolute paths)" - ), - opt[String]("exclude-regex") - .action { (x, c) => - c.ignoredFilesRegex = x.r - c - } - .text( - "a regex specifying files to exclude during CPG generation (paths relative to are matched)" - ), - opt[Unit]("enable-early-schema-checking") - .action((_, c) => c.withSchemaValidation(ValidationMode.Enabled)) - .text("enables early schema validation during AST creation (disabled by default)"), - help("help").text("display this help message"), - frontendSpecific - ) - end commandLineParser - - /** Create an empty CPG, backed by the file at `optionalOutputPath` or in-memory if - * `optionalOutputPath` is empty. - */ - def newEmptyCpg(optionalOutputPath: Option[String] = None): Cpg = - val odbConfig = optionalOutputPath - .map { outputPath => - val outFile = File(outputPath) - if outputPath != "" && outFile.exists then - logger.debug("Output file exists, removing: " + outputPath) - outFile.delete() - Config.withDefaults.withStorageLocation(outputPath) - } - .getOrElse { - Config.withDefaults() - } - Cpg.withConfig(odbConfig) - - /** Apply function `applyPasses` to a newly created CPG. The CPG is wrapped in a `Try` and - * returned. On failure, the CPG is ensured to be closed. - */ - def withNewEmptyCpg[T <: X2CpgConfig[?]]( - outPath: String, - config: T - )(applyPasses: (Cpg, T) => Unit): Try[Cpg] = - val outputPath = if outPath != "" then Some(outPath) else None - Try { - val cpg = newEmptyCpg(outputPath) - Try { - applyPasses(cpg, config) - } match - case Success(_) => cpg - case Failure(exception) => - cpg.close() - throw exception + private val logger = LoggerFactory.getLogger(X2Cpg.getClass) + + /** Parse commands line arguments in `args` using an X2Cpg command line parser, extended with the + * frontend specific options in `frontendSpecific` with the initial configuration set to + * `initialConf`. On success, the configuration is returned wrapped into an Option. On failure, + * error messages are printed and, `None` is returned. + */ + def parseCommandLine[R <: X2CpgConfig[R]]( + args: Array[String], + frontendSpecific: OParser[?, R], + initialConf: R + ): Option[R] = + val parser = commandLineParser(frontendSpecific) + OParser.parse(parser, args, initialConf) + + /** Create a command line parser that can be extended to add options specific for the frontend. + */ + private def commandLineParser[R <: X2CpgConfig[R]](frontendSpecific: OParser[?, R]) + : OParser[?, R] = + val builder = OParser.builder[R] + import builder.* + OParser.sequence( + arg[String]("input-dir") + .text("source directory") + .action((x, c) => c.withInputPath(x)), + opt[String]("output") + .abbr("o") + .text("output filename") + .action { (x, c) => + c.withOutputPath(x) + }, + opt[Seq[String]]("exclude") + .valueName(",,...") + .action { (x, c) => + c.ignoredFiles = c.ignoredFiles ++ x.map(c.createPathForIgnore) + c + } + .text( + "files or folders to exclude during CPG generation (paths relative to or absolute paths)" + ), + opt[String]("exclude-regex") + .action { (x, c) => + c.ignoredFilesRegex = x.r + c + } + .text( + "a regex specifying files to exclude during CPG generation (paths relative to are matched)" + ), + opt[Unit]("enable-early-schema-checking") + .action((_, c) => c.withSchemaValidation(ValidationMode.Enabled)) + .text("enables early schema validation during AST creation (disabled by default)"), + help("help").text("display this help message"), + frontendSpecific + ) + end commandLineParser + + /** Create an empty CPG, backed by the file at `optionalOutputPath` or in-memory if + * `optionalOutputPath` is empty. + */ + def newEmptyCpg(optionalOutputPath: Option[String] = None): Cpg = + val odbConfig = optionalOutputPath + .map { outputPath => + val outFile = File(outputPath) + if outputPath != "" && outFile.exists then + logger.debug("Output file exists, removing: " + outputPath) + outFile.delete() + Config.withDefaults.withStorageLocation(outputPath) } - - /** Given a function that receives a configuration and returns an arbitrary result type wrapped - * in a `Try`, evaluate the function, printing errors to the console. - */ - def withErrorsToConsole[T <: X2CpgConfig[?]](config: T)(f: T => Try[?]): Try[?] = - f(config) match - case Failure(exception) => - exception.printStackTrace() - Failure(exception) - case Success(v) => - Success(v) - - /** For a CPG generated by a frontend, run the default passes that turn a frontend-CPG into a - * complete CPG. - */ - def applyDefaultOverlays(cpg: Cpg): Unit = - val context = new LayerCreatorContext(cpg) - defaultOverlayCreators().foreach { creator => - creator.run(context) + .getOrElse { + Config.withDefaults() } - - /** This should be the only place where we define the list of default overlays. - */ - def defaultOverlayCreators(): List[LayerCreator] = - List(new Base(), new ControlFlow(), new TypeRelations(), new CallGraph()) - - /** Write `sourceCode` to a temporary file inside a temporary directory. The prefix for the - * temporary directory is given by `tmpDirPrefix`. The suffix for the temporary file is given - * by `suffix`. Both file and directory are deleted on exit. - */ - def writeCodeToFile(sourceCode: String, tmpDirPrefix: String, suffix: String): java.io.File = - val tmpDir = Files.createTempDirectory(tmpDirPrefix).toFile - tmpDir.deleteOnExit() - val codeFile = java.io.File.createTempFile("Test", suffix, tmpDir) - codeFile.deleteOnExit() - new PrintWriter(codeFile): - write(sourceCode); close() - tmpDir - - /** Strips surrounding quotation characters from a string. - * @param s - * the target string. - * @return - * the stripped string. - */ - def stripQuotes(str: String): String = str.replaceAll("^(\"|')|(\"|')$", "") + Cpg.withConfig(odbConfig) + + /** Apply function `applyPasses` to a newly created CPG. The CPG is wrapped in a `Try` and + * returned. On failure, the CPG is ensured to be closed. + */ + def withNewEmptyCpg[T <: X2CpgConfig[?]]( + outPath: String, + config: T + )(applyPasses: (Cpg, T) => Unit): Try[Cpg] = + val outputPath = if outPath != "" then Some(outPath) else None + Try { + val cpg = newEmptyCpg(outputPath) + Try { + applyPasses(cpg, config) + } match + case Success(_) => cpg + case Failure(exception) => + cpg.close() + throw exception + } + + /** Given a function that receives a configuration and returns an arbitrary result type wrapped in + * a `Try`, evaluate the function, printing errors to the console. + */ + def withErrorsToConsole[T <: X2CpgConfig[?]](config: T)(f: T => Try[?]): Try[?] = + f(config) match + case Failure(exception) => + exception.printStackTrace() + Failure(exception) + case Success(v) => + Success(v) + + /** For a CPG generated by a frontend, run the default passes that turn a frontend-CPG into a + * complete CPG. + */ + def applyDefaultOverlays(cpg: Cpg): Unit = + val context = new LayerCreatorContext(cpg) + defaultOverlayCreators().foreach { creator => + creator.run(context) + } + + /** This should be the only place where we define the list of default overlays. + */ + def defaultOverlayCreators(): List[LayerCreator] = + List(new Base(), new ControlFlow(), new TypeRelations(), new CallGraph()) + + /** Write `sourceCode` to a temporary file inside a temporary directory. The prefix for the + * temporary directory is given by `tmpDirPrefix`. The suffix for the temporary file is given by + * `suffix`. Both file and directory are deleted on exit. + */ + def writeCodeToFile(sourceCode: String, tmpDirPrefix: String, suffix: String): java.io.File = + val tmpDir = Files.createTempDirectory(tmpDirPrefix).toFile + tmpDir.deleteOnExit() + val codeFile = java.io.File.createTempFile("Test", suffix, tmpDir) + codeFile.deleteOnExit() + new PrintWriter(codeFile): + write(sourceCode); close() + tmpDir + + /** Strips surrounding quotation characters from a string. + * @param s + * the target string. + * @return + * the stripped string. + */ + def stripQuotes(str: String): String = str.replaceAll("^(\"|')|(\"|')$", "") end X2Cpg diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/Global.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/Global.scala index 57fffac3..94226df3 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/Global.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/Global.scala @@ -4,4 +4,4 @@ import java.util.concurrent.ConcurrentHashMap class Global: - val usedTypes: ConcurrentHashMap[String, Boolean] = new ConcurrentHashMap() + val usedTypes: ConcurrentHashMap[String, Boolean] = new ConcurrentHashMap() diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/Scope.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/Scope.scala index 2c09bc26..f0a3488c 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/Scope.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/Scope.scala @@ -10,29 +10,29 @@ package io.appthreat.x2cpg.datastructures * Scope type. */ class Scope[I, V, S]: - protected var stack: List[ScopeElement[I, V, S]] = List[ScopeElement[I, V, S]]() + protected var stack: List[ScopeElement[I, V, S]] = List[ScopeElement[I, V, S]]() - def isEmpty: Boolean = - stack.isEmpty + def isEmpty: Boolean = + stack.isEmpty - def pushNewScope(scopeNode: S): Unit = - stack = ScopeElement[I, V, S](scopeNode) :: stack + def pushNewScope(scopeNode: S): Unit = + stack = ScopeElement[I, V, S](scopeNode) :: stack - def popScope(): Option[S] = - stack match - case Nil => None + def popScope(): Option[S] = + stack match + case Nil => None - case head :: tail => - stack = tail - Some(head.scopeNode) + case head :: tail => + stack = tail + Some(head.scopeNode) - def addToScope(identifier: I, variable: V): S = - stack = stack.head.addVariable(identifier, variable) :: stack.tail - stack.head.scopeNode + def addToScope(identifier: I, variable: V): S = + stack = stack.head.addVariable(identifier, variable) :: stack.tail + stack.head.scopeNode - def lookupVariable(identifier: I): Option[V] = - stack.collectFirst { - case scopeElement if scopeElement.variables.contains(identifier) => - scopeElement.variables(identifier) - } + def lookupVariable(identifier: I): Option[V] = + stack.collectFirst { + case scopeElement if scopeElement.variables.contains(identifier) => + scopeElement.variables(identifier) + } end Scope diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/ScopeElement.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/ScopeElement.scala index fed9d6f4..f6e1fe09 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/ScopeElement.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/ScopeElement.scala @@ -9,5 +9,5 @@ package io.appthreat.x2cpg.datastructures * Scope type. */ case class ScopeElement[I, V, S](scopeNode: S, variables: Map[I, V] = Map[I, V]()): - def addVariable(identifier: I, variable: V): ScopeElement[I, V, S] = - ScopeElement(scopeNode, variables + (identifier -> variable)) + def addVariable(identifier: I, variable: V): ScopeElement[I, V, S] = + ScopeElement(scopeNode, variables + (identifier -> variable)) diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/Stack.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/Stack.scala index f197cfd1..bd720e2c 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/Stack.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/datastructures/Stack.scala @@ -4,11 +4,11 @@ import scala.collection.mutable object Stack: - type Stack[StackElement] = mutable.ListBuffer[StackElement] + type Stack[StackElement] = mutable.ListBuffer[StackElement] - implicit class StackWrapper[StackElement](val parentStack: Stack[StackElement]) extends AnyVal: - def push(parent: StackElement): Unit = - parentStack.prepend(parent) + implicit class StackWrapper[StackElement](val parentStack: Stack[StackElement]) extends AnyVal: + def push(parent: StackElement): Unit = + parentStack.prepend(parent) - def pop(): Unit = - parentStack.remove(0) + def pop(): Unit = + parentStack.remove(0) diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/Base.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/Base.scala index cefbae62..42838c70 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/Base.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/Base.scala @@ -7,32 +7,32 @@ import io.appthreat.x2cpg.passes.base.* import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext, LayerCreatorOptions} object Base: - val overlayName: String = "base" - val description: String = "base layer (linked frontend CPG)" - def defaultOpts = new LayerCreatorOptions() + val overlayName: String = "base" + val description: String = "base layer (linked frontend CPG)" + def defaultOpts = new LayerCreatorOptions() - def passes(cpg: Cpg): Iterator[CpgPassBase] = Iterator( - new FileCreationPass(cpg), - new NamespaceCreator(cpg), - new TypeDeclStubCreator(cpg), - new MethodStubCreator(cpg), - new ParameterIndexCompatPass(cpg), - new MethodDecoratorPass(cpg), - new AstLinkerPass(cpg), - new ContainsEdgePass(cpg), - new TypeUsagePass(cpg) - ) + def passes(cpg: Cpg): Iterator[CpgPassBase] = Iterator( + new FileCreationPass(cpg), + new NamespaceCreator(cpg), + new TypeDeclStubCreator(cpg), + new MethodStubCreator(cpg), + new ParameterIndexCompatPass(cpg), + new MethodDecoratorPass(cpg), + new AstLinkerPass(cpg), + new ContainsEdgePass(cpg), + new TypeUsagePass(cpg) + ) class Base extends LayerCreator: - override val overlayName: String = Base.overlayName - override val description: String = Base.description + override val overlayName: String = Base.overlayName + override val description: String = Base.description - override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = - val cpg = context.cpg - cpg.graph.indexManager.createNodePropertyIndex(PropertyNames.FULL_NAME) - Base.passes(cpg).zipWithIndex.foreach { case (pass, index) => - runPass(pass, context, storeUndoInfo, index) - } + override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = + val cpg = context.cpg + cpg.graph.indexManager.createNodePropertyIndex(PropertyNames.FULL_NAME) + Base.passes(cpg).zipWithIndex.foreach { case (pass, index) => + runPass(pass, context, storeUndoInfo, index) + } - // LayerCreators need one-arg constructor, because they're called by reflection from io.appthreat.console.Run - def this(optionsUnused: LayerCreatorOptions) = this() + // LayerCreators need one-arg constructor, because they're called by reflection from io.appthreat.console.Run + def this(optionsUnused: LayerCreatorOptions) = this() diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/CallGraph.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/CallGraph.scala index 3e104604..fcd9aafa 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/CallGraph.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/CallGraph.scala @@ -6,23 +6,23 @@ import io.appthreat.x2cpg.passes.callgraph.{DynamicCallLinker, MethodRefLinker, import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext, LayerCreatorOptions} object CallGraph: - val overlayName: String = "callgraph" - val description: String = "Call graph layer" - def defaultOpts = new LayerCreatorOptions() + val overlayName: String = "callgraph" + val description: String = "Call graph layer" + def defaultOpts = new LayerCreatorOptions() - def passes(cpg: Cpg): Iterator[CpgPassBase] = - Iterator(new MethodRefLinker(cpg), new StaticCallLinker(cpg), new DynamicCallLinker(cpg)) + def passes(cpg: Cpg): Iterator[CpgPassBase] = + Iterator(new MethodRefLinker(cpg), new StaticCallLinker(cpg), new DynamicCallLinker(cpg)) class CallGraph extends LayerCreator: - override val overlayName: String = CallGraph.overlayName - override val description: String = CallGraph.description - override val dependsOn: List[String] = List(TypeRelations.overlayName) + override val overlayName: String = CallGraph.overlayName + override val description: String = CallGraph.description + override val dependsOn: List[String] = List(TypeRelations.overlayName) - override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = - val cpg = context.cpg - CallGraph.passes(cpg).zipWithIndex.foreach { case (pass, index) => - runPass(pass, context, storeUndoInfo, index) - } + override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = + val cpg = context.cpg + CallGraph.passes(cpg).zipWithIndex.foreach { case (pass, index) => + runPass(pass, context, storeUndoInfo, index) + } - // LayerCreators need one-arg constructor, because they're called by reflection from io.appthreat.console.Run - def this(optionsUnused: LayerCreatorOptions) = this() + // LayerCreators need one-arg constructor, because they're called by reflection from io.appthreat.console.Run + def this(optionsUnused: LayerCreatorOptions) = this() diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/ControlFlow.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/ControlFlow.scala index 136a1a73..5acabec6 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/ControlFlow.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/ControlFlow.scala @@ -10,27 +10,27 @@ import io.appthreat.x2cpg.passes.controlflow.codepencegraph.CdgPass import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext, LayerCreatorOptions} object ControlFlow: - val overlayName: String = "controlflow" - val description: String = "Control flow layer (including dominators and CDG edges)" - def defaultOpts = new LayerCreatorOptions() + val overlayName: String = "controlflow" + val description: String = "Control flow layer (including dominators and CDG edges)" + def defaultOpts = new LayerCreatorOptions() - def passes(cpg: Cpg): Iterator[CpgPassBase] = - val cfgCreationPass = cpg.metaData.language.lastOption match - case Some(Languages.GHIDRA) => Iterator[CpgPassBase]() - case Some(Languages.LLVM) => Iterator[CpgPassBase]() - case _ => Iterator[CpgPassBase](new CfgCreationPass(cpg)) - cfgCreationPass ++ Iterator(new CfgDominatorPass(cpg), new CdgPass(cpg)) + def passes(cpg: Cpg): Iterator[CpgPassBase] = + val cfgCreationPass = cpg.metaData.language.lastOption match + case Some(Languages.GHIDRA) => Iterator[CpgPassBase]() + case Some(Languages.LLVM) => Iterator[CpgPassBase]() + case _ => Iterator[CpgPassBase](new CfgCreationPass(cpg)) + cfgCreationPass ++ Iterator(new CfgDominatorPass(cpg), new CdgPass(cpg)) class ControlFlow extends LayerCreator: - override val overlayName: String = ControlFlow.overlayName - override val description: String = ControlFlow.description - override val dependsOn: List[String] = List(Base.overlayName) + override val overlayName: String = ControlFlow.overlayName + override val description: String = ControlFlow.description + override val dependsOn: List[String] = List(Base.overlayName) - override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = - val cpg = context.cpg - ControlFlow.passes(cpg).zipWithIndex.foreach { case (pass, index) => - runPass(pass, context, storeUndoInfo, index) - } + override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = + val cpg = context.cpg + ControlFlow.passes(cpg).zipWithIndex.foreach { case (pass, index) => + runPass(pass, context, storeUndoInfo, index) + } - // LayerCreators need one-arg constructor, because they're called by reflection from io.appthreat.console.Run - def this(optionsUnused: LayerCreatorOptions) = this() + // LayerCreators need one-arg constructor, because they're called by reflection from io.appthreat.console.Run + def this(optionsUnused: LayerCreatorOptions) = this() diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/DumpAst.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/DumpAst.scala index 7a7b2b47..e1b35cda 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/DumpAst.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/DumpAst.scala @@ -8,20 +8,20 @@ case class AstDumpOptions(var outDir: String) extends LayerCreatorOptions {} object DumpAst: - val overlayName = "dumpAst" + val overlayName = "dumpAst" - val description = "Dump abstract syntax trees to out/" + val description = "Dump abstract syntax trees to out/" - def defaultOpts: AstDumpOptions = AstDumpOptions("out") + def defaultOpts: AstDumpOptions = AstDumpOptions("out") class DumpAst(options: AstDumpOptions) extends LayerCreator: - override val overlayName: String = DumpAst.overlayName - override val description: String = DumpAst.description - override val storeOverlayName: Boolean = false - - override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = - val cpg = context.cpg - cpg.method.zipWithIndex.foreach { case (method, i) => - val str = method.dotAst.head - (File(options.outDir) / s"$i-ast.dot").write(str) - } + override val overlayName: String = DumpAst.overlayName + override val description: String = DumpAst.description + override val storeOverlayName: Boolean = false + + override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = + val cpg = context.cpg + cpg.method.zipWithIndex.foreach { case (method, i) => + val str = method.dotAst.head + (File(options.outDir) / s"$i-ast.dot").write(str) + } diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/DumpCdg.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/DumpCdg.scala index e981900a..38c81beb 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/DumpCdg.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/DumpCdg.scala @@ -8,20 +8,20 @@ case class CdgDumpOptions(var outDir: String) extends LayerCreatorOptions {} object DumpCdg: - val overlayName = "dumpCdg" + val overlayName = "dumpCdg" - val description = "Dump control dependence graph to out/" + val description = "Dump control dependence graph to out/" - def defaultOpts: CdgDumpOptions = CdgDumpOptions("out") + def defaultOpts: CdgDumpOptions = CdgDumpOptions("out") class DumpCdg(options: CdgDumpOptions) extends LayerCreator: - override val overlayName: String = DumpCdg.overlayName - override val description: String = DumpCdg.description - override val storeOverlayName: Boolean = false - - override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = - val cpg = context.cpg - cpg.method.zipWithIndex.foreach { case (method, i) => - val str = method.dotCdg.head - (File(options.outDir) / s"$i-cdg.dot").write(str) - } + override val overlayName: String = DumpCdg.overlayName + override val description: String = DumpCdg.description + override val storeOverlayName: Boolean = false + + override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = + val cpg = context.cpg + cpg.method.zipWithIndex.foreach { case (method, i) => + val str = method.dotCdg.head + (File(options.outDir) / s"$i-cdg.dot").write(str) + } diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/DumpCfg.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/DumpCfg.scala index c80adff4..9577f308 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/DumpCfg.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/DumpCfg.scala @@ -8,20 +8,20 @@ case class CfgDumpOptions(var outDir: String) extends LayerCreatorOptions {} object DumpCfg: - val overlayName = "dumpCfg" + val overlayName = "dumpCfg" - val description = "Dump control flow graph to out/" + val description = "Dump control flow graph to out/" - def defaultOpts: CfgDumpOptions = CfgDumpOptions("out") + def defaultOpts: CfgDumpOptions = CfgDumpOptions("out") class DumpCfg(options: CfgDumpOptions) extends LayerCreator: - override val overlayName: String = DumpCfg.overlayName - override val description: String = DumpCfg.description - override val storeOverlayName: Boolean = false - - override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = - val cpg = context.cpg - cpg.method.zipWithIndex.foreach { case (method, i) => - val str = method.dotCfg.head - (File(options.outDir) / s"$i-cfg.dot").write(str) - } + override val overlayName: String = DumpCfg.overlayName + override val description: String = DumpCfg.description + override val storeOverlayName: Boolean = false + + override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = + val cpg = context.cpg + cpg.method.zipWithIndex.foreach { case (method, i) => + val str = method.dotCfg.head + (File(options.outDir) / s"$i-cfg.dot").write(str) + } diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/TypeRelations.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/TypeRelations.scala index 419d406c..55233f36 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/TypeRelations.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/layers/TypeRelations.scala @@ -6,23 +6,23 @@ import io.appthreat.x2cpg.passes.typerelations.{AliasLinkerPass, TypeHierarchyPa import io.shiftleft.semanticcpg.layers.{LayerCreator, LayerCreatorContext, LayerCreatorOptions} object TypeRelations: - val overlayName: String = "typerel" - val description: String = "Type relations layer (hierarchy and aliases)" - def defaultOpts = new LayerCreatorOptions() + val overlayName: String = "typerel" + val description: String = "Type relations layer (hierarchy and aliases)" + def defaultOpts = new LayerCreatorOptions() - def passes(cpg: Cpg): Iterator[CpgPassBase] = - Iterator(new TypeHierarchyPass(cpg), new AliasLinkerPass(cpg)) + def passes(cpg: Cpg): Iterator[CpgPassBase] = + Iterator(new TypeHierarchyPass(cpg), new AliasLinkerPass(cpg)) class TypeRelations extends LayerCreator: - override val overlayName: String = TypeRelations.overlayName - override val description: String = TypeRelations.description - override val dependsOn: List[String] = List(Base.overlayName) + override val overlayName: String = TypeRelations.overlayName + override val description: String = TypeRelations.description + override val dependsOn: List[String] = List(Base.overlayName) - override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = - val cpg = context.cpg - TypeRelations.passes(cpg).zipWithIndex.foreach { case (pass, index) => - runPass(pass, context, storeUndoInfo, index) - } + override def create(context: LayerCreatorContext, storeUndoInfo: Boolean): Unit = + val cpg = context.cpg + TypeRelations.passes(cpg).zipWithIndex.foreach { case (pass, index) => + runPass(pass, context, storeUndoInfo, index) + } - // Layers need one-arg constructor, because they're called by reflection from io.appthreat.console.Run - def this(optionsUnused: LayerCreatorOptions) = this() + // Layers need one-arg constructor, because they're called by reflection from io.appthreat.console.Run + def this(optionsUnused: LayerCreatorOptions) = this() diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/AstLinkerPass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/AstLinkerPass.scala index 7b4b6dd6..9ee4c21f 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/AstLinkerPass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/AstLinkerPass.scala @@ -9,63 +9,63 @@ import io.shiftleft.semanticcpg.language.* class AstLinkerPass(cpg: Cpg) extends CpgPass(cpg) with LinkingUtil: - override def run(dstGraph: DiffGraphBuilder): Unit = - cpg.method.whereNot(_.astParent).foreach { method => - addAstParent( - method, - method.fullName, - method.astParentType, - method.astParentFullName, - dstGraph - ) - } - cpg.typeDecl.whereNot(_.astParent).foreach { typeDecl => - addAstParent( - typeDecl, - typeDecl.fullName, - typeDecl.astParentType, - typeDecl.astParentFullName, - dstGraph - ) - } - end run + override def run(dstGraph: DiffGraphBuilder): Unit = + cpg.method.whereNot(_.astParent).foreach { method => + addAstParent( + method, + method.fullName, + method.astParentType, + method.astParentFullName, + dstGraph + ) + } + cpg.typeDecl.whereNot(_.astParent).foreach { typeDecl => + addAstParent( + typeDecl, + typeDecl.fullName, + typeDecl.astParentType, + typeDecl.astParentFullName, + dstGraph + ) + } + end run - /** For the given method or type declaration, determine its parent in the AST via the - * AST_PARENT_TYPE and AST_PARENT_FULL_NAME fields and create an AST edge from the parent to - * it. AST creation to methods and type declarations is deferred in frontends in order to allow - * them to process methods/type- declarations independently. - */ - private def addAstParent( - astChild: StoredNode, - astChildFullName: String, - astParentType: String, - astParentFullName: String, - dstGraph: DiffGraphBuilder - ): Unit = - val astParentOption: Option[StoredNode] = - astParentType match - case NodeTypes.METHOD => methodFullNameToNode(cpg, astParentFullName) - case NodeTypes.TYPE_DECL => typeDeclFullNameToNode(cpg, astParentFullName) - case NodeTypes.NAMESPACE_BLOCK => - namespaceBlockFullNameToNode(cpg, astParentFullName) - case _ => - logger.debug( - s"Invalid AST_PARENT_TYPE=$astParentFullName;" + - s" astChild LABEL=${astChild.label};" + - s" astChild FULL_NAME=$astChildFullName" - ) - None + /** For the given method or type declaration, determine its parent in the AST via the + * AST_PARENT_TYPE and AST_PARENT_FULL_NAME fields and create an AST edge from the parent to it. + * AST creation to methods and type declarations is deferred in frontends in order to allow them + * to process methods/type- declarations independently. + */ + private def addAstParent( + astChild: StoredNode, + astChildFullName: String, + astParentType: String, + astParentFullName: String, + dstGraph: DiffGraphBuilder + ): Unit = + val astParentOption: Option[StoredNode] = + astParentType match + case NodeTypes.METHOD => methodFullNameToNode(cpg, astParentFullName) + case NodeTypes.TYPE_DECL => typeDeclFullNameToNode(cpg, astParentFullName) + case NodeTypes.NAMESPACE_BLOCK => + namespaceBlockFullNameToNode(cpg, astParentFullName) + case _ => + logger.debug( + s"Invalid AST_PARENT_TYPE=$astParentFullName;" + + s" astChild LABEL=${astChild.label};" + + s" astChild FULL_NAME=$astChildFullName" + ) + None - astParentOption match - case Some(astParent) => - dstGraph.addEdge(astParent, astChild, EdgeTypes.AST) - case None => - logFailedSrcLookup( - EdgeTypes.AST, - astParentType, - astParentFullName, - astChild.label, - astChild.id.toString - ) - end addAstParent + astParentOption match + case Some(astParent) => + dstGraph.addEdge(astParent, astChild, EdgeTypes.AST) + case None => + logFailedSrcLookup( + EdgeTypes.AST, + astParentType, + astParentFullName, + astChild.label, + astChild.id.toString + ) + end addAstParent end AstLinkerPass diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/ContainsEdgePass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/ContainsEdgePass.scala index 88855d61..99757e4a 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/ContainsEdgePass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/ContainsEdgePass.scala @@ -12,31 +12,31 @@ import scala.jdk.CollectionConverters.* * which do not provide method stubs and type decl stubs. */ class ContainsEdgePass(cpg: Cpg) extends ConcurrentWriterCpgPass[AstNode](cpg): - import ContainsEdgePass.* + import ContainsEdgePass.* - override def generateParts(): Array[AstNode] = - cpg.graph.nodes(sourceTypes*).asScala.map(_.asInstanceOf[AstNode]).toArray + override def generateParts(): Array[AstNode] = + cpg.graph.nodes(sourceTypes*).asScala.map(_.asInstanceOf[AstNode]).toArray - override def runOnPart(dstGraph: DiffGraphBuilder, source: AstNode): Unit = - // AST is assumed to be a tree. If it contains cycles, then this will give a nice endless loop with OOM - val queue = mutable.ArrayDeque[StoredNode](source) - while queue.nonEmpty do - val parent = queue.removeHead() - for nextNode <- parent._astOut do - if isDestinationType(nextNode) then - dstGraph.addEdge(source, nextNode, EdgeTypes.CONTAINS) - if !isSourceType(nextNode) then queue.append(nextNode) + override def runOnPart(dstGraph: DiffGraphBuilder, source: AstNode): Unit = + // AST is assumed to be a tree. If it contains cycles, then this will give a nice endless loop with OOM + val queue = mutable.ArrayDeque[StoredNode](source) + while queue.nonEmpty do + val parent = queue.removeHead() + for nextNode <- parent._astOut do + if isDestinationType(nextNode) then + dstGraph.addEdge(source, nextNode, EdgeTypes.CONTAINS) + if !isSourceType(nextNode) then queue.append(nextNode) object ContainsEdgePass: - private def isSourceType(node: StoredNode): Boolean = node match - case _: Method | _: TypeDecl | _: File => true - case _ => false + private def isSourceType(node: StoredNode): Boolean = node match + case _: Method | _: TypeDecl | _: File => true + case _ => false - private def isDestinationType(node: StoredNode): Boolean = node match - case _: Block | _: Identifier | _: FieldIdentifier | _: Return | _: Method | _: TypeDecl | _: Call | _: Literal | - _: MethodRef | _: TypeRef | _: ControlStructure | _: JumpTarget | _: Unknown | _: TemplateDom => - true - case _ => false + private def isDestinationType(node: StoredNode): Boolean = node match + case _: Block | _: Identifier | _: FieldIdentifier | _: Return | _: Method | _: TypeDecl | _: Call | _: Literal | + _: MethodRef | _: TypeRef | _: ControlStructure | _: JumpTarget | _: Unknown | _: TemplateDom => + true + case _ => false - private val sourceTypes = List(NodeTypes.METHOD, NodeTypes.TYPE_DECL, NodeTypes.FILE) + private val sourceTypes = List(NodeTypes.METHOD, NodeTypes.TYPE_DECL, NodeTypes.FILE) diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/FileCreationPass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/FileCreationPass.scala index e6a9490a..bcbed912 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/FileCreationPass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/FileCreationPass.scala @@ -15,43 +15,43 @@ import scala.collection.mutable */ class FileCreationPass(cpg: Cpg) extends CpgPass(cpg) with LinkingUtil: - override def run(dstGraph: DiffGraphBuilder): Unit = - val originalFileNameToNode = mutable.Map.empty[String, StoredNode] - val newFileNameToNode = mutable.Map.empty[String, NewFile] + override def run(dstGraph: DiffGraphBuilder): Unit = + val originalFileNameToNode = mutable.Map.empty[String, StoredNode] + val newFileNameToNode = mutable.Map.empty[String, NewFile] - cpg.file.foreach { node => - originalFileNameToNode += node.name -> node - } + cpg.file.foreach { node => + originalFileNameToNode += node.name -> node + } - def createFileIfDoesNotExist(srcNode: StoredNode, destFullName: String): Unit = - if destFullName != srcNode.propertyDefaultValue(PropertyNames.FILENAME) then - val dstFullName = if destFullName == "" then FileTraversal.UNKNOWN - else destFullName - val newFile = newFileNameToNode.getOrElseUpdate( - dstFullName, { - val file = NewFile().name(dstFullName).order(0) - dstGraph.addNode(file) - file - } - ) - dstGraph.addEdge(srcNode, newFile, EdgeTypes.SOURCE_FILE) + def createFileIfDoesNotExist(srcNode: StoredNode, destFullName: String): Unit = + if destFullName != srcNode.propertyDefaultValue(PropertyNames.FILENAME) then + val dstFullName = if destFullName == "" then FileTraversal.UNKNOWN + else destFullName + val newFile = newFileNameToNode.getOrElseUpdate( + dstFullName, { + val file = NewFile().name(dstFullName).order(0) + dstGraph.addNode(file) + file + } + ) + dstGraph.addEdge(srcNode, newFile, EdgeTypes.SOURCE_FILE) - // Create SOURCE_FILE edges from nodes of various types to FILE - linkToSingle( - cpg, - srcLabels = List( - NodeTypes.NAMESPACE_BLOCK, - NodeTypes.TYPE_DECL, - NodeTypes.METHOD, - NodeTypes.COMMENT - ), - dstNodeLabel = NodeTypes.FILE, - edgeType = EdgeTypes.SOURCE_FILE, - dstNodeMap = x => - originalFileNameToNode.get(x), - dstFullNameKey = PropertyNames.FILENAME, - dstGraph, - Some(createFileIfDoesNotExist) - ) - end run + // Create SOURCE_FILE edges from nodes of various types to FILE + linkToSingle( + cpg, + srcLabels = List( + NodeTypes.NAMESPACE_BLOCK, + NodeTypes.TYPE_DECL, + NodeTypes.METHOD, + NodeTypes.COMMENT + ), + dstNodeLabel = NodeTypes.FILE, + edgeType = EdgeTypes.SOURCE_FILE, + dstNodeMap = x => + originalFileNameToNode.get(x), + dstFullNameKey = PropertyNames.FILENAME, + dstGraph, + Some(createFileIfDoesNotExist) + ) + end run end FileCreationPass diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/MethodDecoratorPass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/MethodDecoratorPass.scala index d11f6c9b..9e829134 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/MethodDecoratorPass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/MethodDecoratorPass.scala @@ -14,47 +14,47 @@ import org.slf4j.{Logger, LoggerFactory} * method stubs. */ class MethodDecoratorPass(cpg: Cpg) extends CpgPass(cpg): - import MethodDecoratorPass.logger + import MethodDecoratorPass.logger - private var loggedDeprecatedWarning = false - private var loggedMissingTypeFullName = false + private var loggedDeprecatedWarning = false + private var loggedMissingTypeFullName = false - override def run(dstGraph: DiffGraphBuilder): Unit = - cpg.parameter.foreach { parameterIn => - if !parameterIn._parameterLinkOut.hasNext then - val parameterOut = nodes - .NewMethodParameterOut() - .code(parameterIn.code) - .order(parameterIn.order) - .index(parameterIn.index) - .name(parameterIn.name) - .evaluationStrategy(parameterIn.evaluationStrategy) - .typeFullName(parameterIn.typeFullName) - .isVariadic(parameterIn.isVariadic) - .lineNumber(parameterIn.lineNumber) - .columnNumber(parameterIn.columnNumber) + override def run(dstGraph: DiffGraphBuilder): Unit = + cpg.parameter.foreach { parameterIn => + if !parameterIn._parameterLinkOut.hasNext then + val parameterOut = nodes + .NewMethodParameterOut() + .code(parameterIn.code) + .order(parameterIn.order) + .index(parameterIn.index) + .name(parameterIn.name) + .evaluationStrategy(parameterIn.evaluationStrategy) + .typeFullName(parameterIn.typeFullName) + .isVariadic(parameterIn.isVariadic) + .lineNumber(parameterIn.lineNumber) + .columnNumber(parameterIn.columnNumber) - val method = parameterIn.astIn.headOption - if method.isEmpty then - logger.debug("Parameter without method encountered: " + parameterIn.toString) - else - if parameterIn.typeFullName == null then - val evalType = parameterIn.typ - dstGraph.addEdge(parameterOut, evalType, EdgeTypes.EVAL_TYPE) - if !loggedMissingTypeFullName then - logger.debug( - "Using deprecated CPG format with missing TYPE_FULL_NAME on METHOD_PARAMETER_IN nodes." - ) - loggedMissingTypeFullName = true + val method = parameterIn.astIn.headOption + if method.isEmpty then + logger.debug("Parameter without method encountered: " + parameterIn.toString) + else + if parameterIn.typeFullName == null then + val evalType = parameterIn.typ + dstGraph.addEdge(parameterOut, evalType, EdgeTypes.EVAL_TYPE) + if !loggedMissingTypeFullName then + logger.debug( + "Using deprecated CPG format with missing TYPE_FULL_NAME on METHOD_PARAMETER_IN nodes." + ) + loggedMissingTypeFullName = true - dstGraph.addNode(parameterOut) - dstGraph.addEdge(method.get, parameterOut, EdgeTypes.AST) - dstGraph.addEdge(parameterIn, parameterOut, EdgeTypes.PARAMETER_LINK) - else if !loggedDeprecatedWarning then - logger.debug("Using deprecated CPG format with PARAMETER_LINK edges") - loggedDeprecatedWarning = true - } + dstGraph.addNode(parameterOut) + dstGraph.addEdge(method.get, parameterOut, EdgeTypes.AST) + dstGraph.addEdge(parameterIn, parameterOut, EdgeTypes.PARAMETER_LINK) + else if !loggedDeprecatedWarning then + logger.debug("Using deprecated CPG format with PARAMETER_LINK edges") + loggedDeprecatedWarning = true + } end MethodDecoratorPass object MethodDecoratorPass: - private val logger: Logger = LoggerFactory.getLogger(classOf[MethodDecoratorPass]) + private val logger: Logger = LoggerFactory.getLogger(classOf[MethodDecoratorPass]) diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/MethodStubCreator.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/MethodStubCreator.scala index 19c72046..b871a274 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/MethodStubCreator.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/MethodStubCreator.scala @@ -24,115 +24,115 @@ case class CallSummary(name: String, signature: String, fullName: String, dispat */ class MethodStubCreator(cpg: Cpg) extends CpgPass(cpg): - // Since the method fullNames for fuzzyc are not unique, we do not have - // a 1to1 relation and may overwrite some values. This is ok for now. - private val methodFullNameToNode = mutable.LinkedHashMap[String, Method]() - private val methodToParameterCount = mutable.LinkedHashMap[CallSummary, Int]() - - override def run(dstGraph: BatchedUpdate.DiffGraphBuilder): Unit = - for method <- cpg.method do - methodFullNameToNode.put(method.fullName, method) - - for call <- cpg.call if call.methodFullName != Defines.DynamicCallUnknownFullName do - methodToParameterCount.put( - CallSummary(call.name, call.signature, call.methodFullName, call.dispatchType), - call.argument.size - ) - - for - (CallSummary(name, signature, fullName, dispatchType), parameterCount) <- - methodToParameterCount - if !methodFullNameToNode.contains(fullName) - do - createMethodStub(name, fullName, signature, dispatchType, parameterCount, dstGraph) - - override def finish(): Unit = - methodFullNameToNode.clear() - methodToParameterCount.clear() - super.finish() + // Since the method fullNames for fuzzyc are not unique, we do not have + // a 1to1 relation and may overwrite some values. This is ok for now. + private val methodFullNameToNode = mutable.LinkedHashMap[String, Method]() + private val methodToParameterCount = mutable.LinkedHashMap[CallSummary, Int]() + + override def run(dstGraph: BatchedUpdate.DiffGraphBuilder): Unit = + for method <- cpg.method do + methodFullNameToNode.put(method.fullName, method) + + for call <- cpg.call if call.methodFullName != Defines.DynamicCallUnknownFullName do + methodToParameterCount.put( + CallSummary(call.name, call.signature, call.methodFullName, call.dispatchType), + call.argument.size + ) + + for + (CallSummary(name, signature, fullName, dispatchType), parameterCount) <- + methodToParameterCount + if !methodFullNameToNode.contains(fullName) + do + createMethodStub(name, fullName, signature, dispatchType, parameterCount, dstGraph) + + override def finish(): Unit = + methodFullNameToNode.clear() + methodToParameterCount.clear() + super.finish() end MethodStubCreator object MethodStubCreator: - private def addLineNumberInfo(methodNode: NewMethod, fullName: String): NewMethod = - val s = fullName.split(":") - if - s.size == 5 && Try { - s(1).toInt - }.isSuccess && Try { - s(2).toInt - }.isSuccess - then - val filename = s(0) - val lineNumber = s(1).toInt - val lineNumberEnd = s(2).toInt - methodNode - .filename(filename) - .lineNumber(lineNumber) - .lineNumberEnd(lineNumberEnd) - else - methodNode - - def createMethodStub( - name: String, - fullName: String, - signature: String, - dispatchType: String, - parameterCount: Int, - dstGraph: DiffGraphBuilder, - isExternal: Boolean = true, - astParentType: String = NodeTypes.NAMESPACE_BLOCK, - astParentFullName: String = "" - ): NewMethod = - val methodNode = NewMethod() - .name(name) - .fullName(fullName) - .isExternal(isExternal) - .signature(signature) - .astParentType(astParentType) - .astParentFullName(astParentFullName) - .order(0) - - addLineNumberInfo(methodNode, fullName) - - dstGraph.addNode(methodNode) - - val firstParameterIndex = dispatchType match - case DispatchTypes.DYNAMIC_DISPATCH => - 0 - case _ => - 1 - - (firstParameterIndex to parameterCount).foreach { parameterOrder => - val nameAndCode = s"p$parameterOrder" - val param = NewMethodParameterIn() - .code(nameAndCode) - .order(parameterOrder) - .name(nameAndCode) - .evaluationStrategy(EvaluationStrategies.BY_VALUE) - .typeFullName("ANY") - - dstGraph.addNode(param) - dstGraph.addEdge(methodNode, param, EdgeTypes.AST) - } - - val blockNode = NewBlock() - .order(1) - .argumentIndex(1) - .typeFullName("ANY") - - dstGraph.addNode(blockNode) - dstGraph.addEdge(methodNode, blockNode, EdgeTypes.AST) - - val methodReturn = NewMethodReturn() - .order(2) - .code("RET") - .evaluationStrategy(EvaluationStrategies.BY_VALUE) - .typeFullName("ANY") - - dstGraph.addNode(methodReturn) - dstGraph.addEdge(methodNode, methodReturn, EdgeTypes.AST) - - methodNode - end createMethodStub + private def addLineNumberInfo(methodNode: NewMethod, fullName: String): NewMethod = + val s = fullName.split(":") + if + s.size == 5 && Try { + s(1).toInt + }.isSuccess && Try { + s(2).toInt + }.isSuccess + then + val filename = s(0) + val lineNumber = s(1).toInt + val lineNumberEnd = s(2).toInt + methodNode + .filename(filename) + .lineNumber(lineNumber) + .lineNumberEnd(lineNumberEnd) + else + methodNode + + def createMethodStub( + name: String, + fullName: String, + signature: String, + dispatchType: String, + parameterCount: Int, + dstGraph: DiffGraphBuilder, + isExternal: Boolean = true, + astParentType: String = NodeTypes.NAMESPACE_BLOCK, + astParentFullName: String = "" + ): NewMethod = + val methodNode = NewMethod() + .name(name) + .fullName(fullName) + .isExternal(isExternal) + .signature(signature) + .astParentType(astParentType) + .astParentFullName(astParentFullName) + .order(0) + + addLineNumberInfo(methodNode, fullName) + + dstGraph.addNode(methodNode) + + val firstParameterIndex = dispatchType match + case DispatchTypes.DYNAMIC_DISPATCH => + 0 + case _ => + 1 + + (firstParameterIndex to parameterCount).foreach { parameterOrder => + val nameAndCode = s"p$parameterOrder" + val param = NewMethodParameterIn() + .code(nameAndCode) + .order(parameterOrder) + .name(nameAndCode) + .evaluationStrategy(EvaluationStrategies.BY_VALUE) + .typeFullName("ANY") + + dstGraph.addNode(param) + dstGraph.addEdge(methodNode, param, EdgeTypes.AST) + } + + val blockNode = NewBlock() + .order(1) + .argumentIndex(1) + .typeFullName("ANY") + + dstGraph.addNode(blockNode) + dstGraph.addEdge(methodNode, blockNode, EdgeTypes.AST) + + val methodReturn = NewMethodReturn() + .order(2) + .code("RET") + .evaluationStrategy(EvaluationStrategies.BY_VALUE) + .typeFullName("ANY") + + dstGraph.addNode(methodReturn) + dstGraph.addEdge(methodNode, methodReturn, EdgeTypes.AST) + + methodNode + end createMethodStub end MethodStubCreator diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/NamespaceCreator.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/NamespaceCreator.scala index 6d08e718..487ab217 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/NamespaceCreator.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/NamespaceCreator.scala @@ -12,13 +12,13 @@ import io.shiftleft.semanticcpg.language.* */ class NamespaceCreator(cpg: Cpg) extends CpgPass(cpg): - /** Creates NAMESPACE nodes and connects NAMESPACE_BLOCKs to corresponding NAMESPACE nodes. - */ - override def run(dstGraph: DiffGraphBuilder): Unit = - cpg.namespaceBlock - .groupBy(_.name) - .foreach { case (name: String, blocks) => - val namespace = NewNamespace().name(name) - dstGraph.addNode(namespace) - blocks.foreach(block => dstGraph.addEdge(block, namespace, EdgeTypes.REF)) - } + /** Creates NAMESPACE nodes and connects NAMESPACE_BLOCKs to corresponding NAMESPACE nodes. + */ + override def run(dstGraph: DiffGraphBuilder): Unit = + cpg.namespaceBlock + .groupBy(_.name) + .foreach { case (name: String, blocks) => + val namespace = NewNamespace().name(name) + dstGraph.addNode(namespace) + blocks.foreach(block => dstGraph.addEdge(block, namespace, EdgeTypes.REF)) + } diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/ParameterIndexCompatPass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/ParameterIndexCompatPass.scala index f59f8d25..940db7f9 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/ParameterIndexCompatPass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/ParameterIndexCompatPass.scala @@ -13,8 +13,8 @@ import io.shiftleft.semanticcpg.language.* */ class ParameterIndexCompatPass(cpg: Cpg) extends CpgPass(cpg): - override def run(diffGraph: BatchedUpdate.DiffGraphBuilder): Unit = - cpg.parameter.foreach { param => - if param.index == PropertyDefaults.Index then - diffGraph.setNodeProperty(param, PropertyNames.INDEX, param.order) - } + override def run(diffGraph: BatchedUpdate.DiffGraphBuilder): Unit = + cpg.parameter.foreach { param => + if param.index == PropertyDefaults.Index then + diffGraph.setNodeProperty(param, PropertyNames.INDEX, param.order) + } diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/TypeDeclStubCreator.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/TypeDeclStubCreator.scala index f5717955..f5d3be03 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/TypeDeclStubCreator.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/TypeDeclStubCreator.scala @@ -13,41 +13,41 @@ import io.shiftleft.semanticcpg.language.types.structure.{FileTraversal, Namespa */ class TypeDeclStubCreator(cpg: Cpg) extends CpgPass(cpg): - private var typeDeclFullNameToNode = Map[String, TypeDeclBase]() - - private def privateInit(): Unit = - cpg.typeDecl - .foreach { typeDecl => - typeDeclFullNameToNode += typeDecl.fullName -> typeDecl - } - - override def run(dstGraph: DiffGraphBuilder): Unit = - privateInit() - - cpg.typ - .filterNot(typ => typeDeclFullNameToNode.isDefinedAt(typ.fullName)) - .foreach { typ => - val newTypeDecl = TypeDeclStubCreator.createTypeDeclStub(typ.name, typ.fullName) - typeDeclFullNameToNode += typ.fullName -> newTypeDecl - dstGraph.addNode(newTypeDecl) - } + private var typeDeclFullNameToNode = Map[String, TypeDeclBase]() + + private def privateInit(): Unit = + cpg.typeDecl + .foreach { typeDecl => + typeDeclFullNameToNode += typeDecl.fullName -> typeDecl + } + + override def run(dstGraph: DiffGraphBuilder): Unit = + privateInit() + + cpg.typ + .filterNot(typ => typeDeclFullNameToNode.isDefinedAt(typ.fullName)) + .foreach { typ => + val newTypeDecl = TypeDeclStubCreator.createTypeDeclStub(typ.name, typ.fullName) + typeDeclFullNameToNode += typ.fullName -> newTypeDecl + dstGraph.addNode(newTypeDecl) + } end TypeDeclStubCreator object TypeDeclStubCreator: - def createTypeDeclStub( - name: String, - fullName: String, - isExternal: Boolean = true, - astParentType: String = NodeTypes.NAMESPACE_BLOCK, - astParentFullName: String = NamespaceTraversal.globalNamespaceName, - fileName: String = FileTraversal.UNKNOWN - ): NewTypeDecl = - NewTypeDecl() - .name(name) - .fullName(fullName) - .isExternal(isExternal) - .inheritsFromTypeFullName(IndexedSeq.empty) - .astParentType(astParentType) - .astParentFullName(astParentFullName) - .filename(fileName) + def createTypeDeclStub( + name: String, + fullName: String, + isExternal: Boolean = true, + astParentType: String = NodeTypes.NAMESPACE_BLOCK, + astParentFullName: String = NamespaceTraversal.globalNamespaceName, + fileName: String = FileTraversal.UNKNOWN + ): NewTypeDecl = + NewTypeDecl() + .name(name) + .fullName(fullName) + .isExternal(isExternal) + .inheritsFromTypeFullName(IndexedSeq.empty) + .astParentType(astParentType) + .astParentFullName(astParentFullName) + .filename(fileName) diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/TypeUsagePass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/TypeUsagePass.scala index c46d5ca0..ce2dd311 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/TypeUsagePass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/base/TypeUsagePass.scala @@ -7,42 +7,42 @@ import io.shiftleft.passes.CpgPass class TypeUsagePass(cpg: Cpg) extends CpgPass(cpg) with LinkingUtil: - override def run(dstGraph: DiffGraphBuilder): Unit = - // Create REF edges from TYPE nodes to TYPE_DECL - linkToSingle( - cpg, - srcLabels = List(NodeTypes.TYPE), - dstNodeLabel = NodeTypes.TYPE_DECL, - edgeType = EdgeTypes.REF, - dstNodeMap = typeDeclFullNameToNode(cpg, _), - dstFullNameKey = PropertyNames.TYPE_DECL_FULL_NAME, - dstGraph, - None - ) + override def run(dstGraph: DiffGraphBuilder): Unit = + // Create REF edges from TYPE nodes to TYPE_DECL + linkToSingle( + cpg, + srcLabels = List(NodeTypes.TYPE), + dstNodeLabel = NodeTypes.TYPE_DECL, + edgeType = EdgeTypes.REF, + dstNodeMap = typeDeclFullNameToNode(cpg, _), + dstFullNameKey = PropertyNames.TYPE_DECL_FULL_NAME, + dstGraph, + None + ) - // Create EVAL_TYPE edges from nodes of various types to TYPE - linkToSingle( - cpg, - srcLabels = List( - NodeTypes.METHOD_PARAMETER_IN, - NodeTypes.METHOD_PARAMETER_OUT, - NodeTypes.METHOD_RETURN, - NodeTypes.MEMBER, - NodeTypes.LITERAL, - NodeTypes.CALL, - NodeTypes.LOCAL, - NodeTypes.IDENTIFIER, - NodeTypes.BLOCK, - NodeTypes.METHOD_REF, - NodeTypes.TYPE_REF, - NodeTypes.UNKNOWN - ), - dstNodeLabel = NodeTypes.TYPE, - edgeType = EdgeTypes.EVAL_TYPE, - dstNodeMap = typeFullNameToNode(cpg, _), - dstFullNameKey = "TYPE_FULL_NAME", - dstGraph, - None - ) - end run + // Create EVAL_TYPE edges from nodes of various types to TYPE + linkToSingle( + cpg, + srcLabels = List( + NodeTypes.METHOD_PARAMETER_IN, + NodeTypes.METHOD_PARAMETER_OUT, + NodeTypes.METHOD_RETURN, + NodeTypes.MEMBER, + NodeTypes.LITERAL, + NodeTypes.CALL, + NodeTypes.LOCAL, + NodeTypes.IDENTIFIER, + NodeTypes.BLOCK, + NodeTypes.METHOD_REF, + NodeTypes.TYPE_REF, + NodeTypes.UNKNOWN + ), + dstNodeLabel = NodeTypes.TYPE, + edgeType = EdgeTypes.EVAL_TYPE, + dstNodeMap = typeFullNameToNode(cpg, _), + dstFullNameKey = "TYPE_FULL_NAME", + dstGraph, + None + ) + end run end TypeUsagePass diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/callgraph/DynamicCallLinker.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/callgraph/DynamicCallLinker.scala index a4be9677..8f96d7a2 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/callgraph/DynamicCallLinker.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/callgraph/DynamicCallLinker.scala @@ -26,225 +26,225 @@ import scala.jdk.CollectionConverters.* */ class DynamicCallLinker(cpg: Cpg) extends CpgPass(cpg): - import DynamicCallLinker.* - // Used to track potential method candidates for a given method fullname. Since our method full names contain the type - // decl we don't need to specify an addition map to wrap this in. LinkedHashSets are used here to preserve order in - // the best interest of reproducibility during debugging. - private val validM = mutable.Map.empty[String, mutable.LinkedHashSet[String]] - // Used for dynamic programming as subtree's don't need to be recalculated later - private val subclassCache = mutable.Map.empty[String, mutable.LinkedHashSet[String]] - private val superclassCache = mutable.Map.empty[String, mutable.LinkedHashSet[String]] - // Used for O(1) lookups on methods that will work without indexManager - private val typeMap = mutable.Map.empty[String, TypeDecl] - // For linking loose method stubs that cannot be resolved by crawling parent types - private val methodMap = mutable.Map.empty[String, Method] - - private def initMaps(): Unit = - cpg.typeDecl.foreach { typeDecl => - typeMap += (typeDecl.fullName -> typeDecl) - } - cpg.method - .filter(m => !m.name.startsWith("")) - .foreach { method => methodMap += (method.fullName -> method) } - - /** Main method of enhancement - to be implemented by child class - */ - override def run(dstGraph: DiffGraphBuilder): Unit = - // Perform early stopping in the case of no virtual calls - if !cpg.call.exists(_.dispatchType == DispatchTypes.DYNAMIC_DISPATCH) then - return - initMaps() - // ValidM maps class C and method name N to the set of - // func ptrs implementing N for C and its subclasses - for - typeDecl <- cpg.typeDecl; - method <- typeDecl._methodViaAstOut - do - val methodName = method.fullName - val candidates = allSubclasses(typeDecl.fullName).flatMap { staticLookup(_, method) } - validM.put(methodName, candidates) - - subclassCache.clear() - - cpg.call.filter(_.dispatchType == DispatchTypes.DYNAMIC_DISPATCH).foreach { call => - try - linkDynamicCall(call, dstGraph) - catch - case exception: Exception => - throw new RuntimeException(exception) - } - end run - - /** Recursively returns all the sub-types of the given type declaration. Does account for - * circular hierarchies. - */ - private def allSubclasses(typDeclFullName: String): mutable.LinkedHashSet[String] = - inheritanceTraversal(typDeclFullName, subclassCache, inSuperDirection = false) - - /** Recursively returns all the super-types of the given type declaration. Does account for - * circular hierarchies. - */ - private def allSuperClasses(typDeclFullName: String): mutable.LinkedHashSet[String] = - inheritanceTraversal(typDeclFullName, superclassCache, inSuperDirection = true) - - private def inheritanceTraversal( - typDeclFullName: String, - cache: mutable.Map[String, mutable.LinkedHashSet[String]], - inSuperDirection: Boolean - ): mutable.LinkedHashSet[String] = - cache.get(typDeclFullName) match - case Some(superClasses) => superClasses - case None => - val totalSuperclasses = (cpg.typeDecl - .fullNameExact(typDeclFullName) - .headOption match - case Some(curr) => inheritTraversal(curr, inSuperDirection) - case None => mutable.LinkedHashSet.empty - ).map(_.fullName) - cache.put(typDeclFullName, totalSuperclasses) - totalSuperclasses - - private def inheritTraversal( - cur: TypeDecl, - inSuperDirection: Boolean, - visitedNodes: mutable.LinkedHashSet[TypeDecl] = mutable.LinkedHashSet.empty - ): mutable.LinkedHashSet[TypeDecl] = - if visitedNodes.contains(cur) then return visitedNodes - visitedNodes.addOne(cur) - - (if inSuperDirection then - cpg.typeDecl.fullNameExact(cur.fullName).flatMap(_.inheritsFromOut.referencedTypeDecl) - else cpg.typ.fullNameExact(cur.fullName).flatMap(_.inheritsFromIn)) - .collectAll[TypeDecl] - .to(mutable.LinkedHashSet) match - case classesToEval if classesToEval.isEmpty => visitedNodes - case classesToEval => - classesToEval.flatMap(t => inheritTraversal(t, inSuperDirection, visitedNodes)) - visitedNodes - - /** Returns the method from a sub-class implementing a method for the given subclass. - */ - private def staticLookup(subclass: String, method: Method): Option[String] = - typeMap.get(subclass) match - case Some(sc) => - sc._methodViaAstOut - .nameExact(method.name) - .and(_.signatureExact(method.signature)) - .map(_.fullName) - .headOption - case None => None - - private def resolveCallInSuperClasses(call: Call): Boolean = - if !call.methodFullName.contains(":") && !call.methodFullName.contains(".") then - return false - def split(str: String, n: Int) = (str.take(n), str.drop(n + 1)) - val (fullName, signature) = if call.methodFullName.contains(":") then - split(call.methodFullName, call.methodFullName.lastIndexOf(":")) - else split(call.methodFullName, call.methodFullName.lastIndexOf(".")) - val typeDeclFullName = fullName.replace(s".${call.name}", "") - val candidateInheritedMethods = - cpg.typeDecl - .fullNameExact(allSuperClasses(typeDeclFullName).toIndexedSeq*) - .astChildren - .isMethod - .nameExact(call.name) - .and(_.signatureExact(signature)) - .fullName - .l - if candidateInheritedMethods.nonEmpty then + import DynamicCallLinker.* + // Used to track potential method candidates for a given method fullname. Since our method full names contain the type + // decl we don't need to specify an addition map to wrap this in. LinkedHashSets are used here to preserve order in + // the best interest of reproducibility during debugging. + private val validM = mutable.Map.empty[String, mutable.LinkedHashSet[String]] + // Used for dynamic programming as subtree's don't need to be recalculated later + private val subclassCache = mutable.Map.empty[String, mutable.LinkedHashSet[String]] + private val superclassCache = mutable.Map.empty[String, mutable.LinkedHashSet[String]] + // Used for O(1) lookups on methods that will work without indexManager + private val typeMap = mutable.Map.empty[String, TypeDecl] + // For linking loose method stubs that cannot be resolved by crawling parent types + private val methodMap = mutable.Map.empty[String, Method] + + private def initMaps(): Unit = + cpg.typeDecl.foreach { typeDecl => + typeMap += (typeDecl.fullName -> typeDecl) + } + cpg.method + .filter(m => !m.name.startsWith("")) + .foreach { method => methodMap += (method.fullName -> method) } + + /** Main method of enhancement - to be implemented by child class + */ + override def run(dstGraph: DiffGraphBuilder): Unit = + // Perform early stopping in the case of no virtual calls + if !cpg.call.exists(_.dispatchType == DispatchTypes.DYNAMIC_DISPATCH) then + return + initMaps() + // ValidM maps class C and method name N to the set of + // func ptrs implementing N for C and its subclasses + for + typeDecl <- cpg.typeDecl; + method <- typeDecl._methodViaAstOut + do + val methodName = method.fullName + val candidates = allSubclasses(typeDecl.fullName).flatMap { staticLookup(_, method) } + validM.put(methodName, candidates) + + subclassCache.clear() + + cpg.call.filter(_.dispatchType == DispatchTypes.DYNAMIC_DISPATCH).foreach { call => + try + linkDynamicCall(call, dstGraph) + catch + case exception: Exception => + throw new RuntimeException(exception) + } + end run + + /** Recursively returns all the sub-types of the given type declaration. Does account for circular + * hierarchies. + */ + private def allSubclasses(typDeclFullName: String): mutable.LinkedHashSet[String] = + inheritanceTraversal(typDeclFullName, subclassCache, inSuperDirection = false) + + /** Recursively returns all the super-types of the given type declaration. Does account for + * circular hierarchies. + */ + private def allSuperClasses(typDeclFullName: String): mutable.LinkedHashSet[String] = + inheritanceTraversal(typDeclFullName, superclassCache, inSuperDirection = true) + + private def inheritanceTraversal( + typDeclFullName: String, + cache: mutable.Map[String, mutable.LinkedHashSet[String]], + inSuperDirection: Boolean + ): mutable.LinkedHashSet[String] = + cache.get(typDeclFullName) match + case Some(superClasses) => superClasses + case None => + val totalSuperclasses = (cpg.typeDecl + .fullNameExact(typDeclFullName) + .headOption match + case Some(curr) => inheritTraversal(curr, inSuperDirection) + case None => mutable.LinkedHashSet.empty + ).map(_.fullName) + cache.put(typDeclFullName, totalSuperclasses) + totalSuperclasses + + private def inheritTraversal( + cur: TypeDecl, + inSuperDirection: Boolean, + visitedNodes: mutable.LinkedHashSet[TypeDecl] = mutable.LinkedHashSet.empty + ): mutable.LinkedHashSet[TypeDecl] = + if visitedNodes.contains(cur) then return visitedNodes + visitedNodes.addOne(cur) + + (if inSuperDirection then + cpg.typeDecl.fullNameExact(cur.fullName).flatMap(_.inheritsFromOut.referencedTypeDecl) + else cpg.typ.fullNameExact(cur.fullName).flatMap(_.inheritsFromIn)) + .collectAll[TypeDecl] + .to(mutable.LinkedHashSet) match + case classesToEval if classesToEval.isEmpty => visitedNodes + case classesToEval => + classesToEval.flatMap(t => inheritTraversal(t, inSuperDirection, visitedNodes)) + visitedNodes + + /** Returns the method from a sub-class implementing a method for the given subclass. + */ + private def staticLookup(subclass: String, method: Method): Option[String] = + typeMap.get(subclass) match + case Some(sc) => + sc._methodViaAstOut + .nameExact(method.name) + .and(_.signatureExact(method.signature)) + .map(_.fullName) + .headOption + case None => None + + private def resolveCallInSuperClasses(call: Call): Boolean = + if !call.methodFullName.contains(":") && !call.methodFullName.contains(".") then + return false + def split(str: String, n: Int) = (str.take(n), str.drop(n + 1)) + val (fullName, signature) = if call.methodFullName.contains(":") then + split(call.methodFullName, call.methodFullName.lastIndexOf(":")) + else split(call.methodFullName, call.methodFullName.lastIndexOf(".")) + val typeDeclFullName = fullName.replace(s".${call.name}", "") + val candidateInheritedMethods = + cpg.typeDecl + .fullNameExact(allSuperClasses(typeDeclFullName).toIndexedSeq*) + .astChildren + .isMethod + .nameExact(call.name) + .and(_.signatureExact(signature)) + .fullName + .l + if candidateInheritedMethods.nonEmpty then + validM.put( + call.methodFullName, + validM.getOrElse( + call.methodFullName, + mutable.LinkedHashSet.empty + ) ++ mutable.LinkedHashSet.from( + candidateInheritedMethods + ) + ) + true + else if call.methodFullName == ".indirectFieldAccess" then + val calledMethodName = call.argument.last.code + val fieldTypes = call.argument.head.typ.l + fieldTypes.foreach { ft => + val ftSubClasses = allSubclasses(ft.fullName).filterNot(cn => cn == ft.fullName) + if ftSubClasses.nonEmpty then + val candidateSubTypeMethods = cpg.typeDecl.fullNameExact( + ftSubClasses.toIndexedSeq* + ).astChildren.isMethod.name(calledMethodName).fullName.l + if candidateSubTypeMethods.nonEmpty then validM.put( - call.methodFullName, + calledMethodName, validM.getOrElse( - call.methodFullName, + calledMethodName, mutable.LinkedHashSet.empty ) ++ mutable.LinkedHashSet.from( - candidateInheritedMethods + candidateSubTypeMethods ) ) - true - else if call.methodFullName == ".indirectFieldAccess" then - val calledMethodName = call.argument.last.code - val fieldTypes = call.argument.head.typ.l - fieldTypes.foreach { ft => - val ftSubClasses = allSubclasses(ft.fullName).filterNot(cn => cn == ft.fullName) - if ftSubClasses.nonEmpty then - val candidateSubTypeMethods = cpg.typeDecl.fullNameExact( - ftSubClasses.toIndexedSeq* - ).astChildren.isMethod.name(calledMethodName).fullName.l - if candidateSubTypeMethods.nonEmpty then - validM.put( - calledMethodName, - validM.getOrElse( - calledMethodName, - mutable.LinkedHashSet.empty - ) ++ mutable.LinkedHashSet.from( - candidateSubTypeMethods - ) - ) - } - true - else - false - end if - end resolveCallInSuperClasses - - private def linkDynamicCall(call: Call, dstGraph: DiffGraphBuilder): Unit = - // This call linker requires a method full name entry - if call.methodFullName.equals("") || call.methodFullName.equals( - DynamicCallUnknownFullName - ) - then return - // Support for overriding - val resolved = resolveCallInSuperClasses(call) - var methodNameToUse = call.methodFullName - if call.methodFullName.startsWith("") && resolved then - methodNameToUse = call.argument.last.code - validM.get(methodNameToUse) match - case Some(tgts) => - val callsOut = call.callOut.fullName.toSetImmutable - val tgtMs = tgts - .flatMap(destMethod => - if cpg.graph.indexManager.isIndexed(PropertyNames.FULL_NAME) then - methodFullNameToNode(destMethod) - else - cpg.method.fullNameExact(destMethod).headOption - ) - .toSet - // Non-overridden methods linked as external stubs should be excluded if they are detected - val (externalMs, internalMs) = tgtMs.partition(_.isExternal) - (if externalMs.nonEmpty && internalMs.nonEmpty then internalMs else tgtMs) - .foreach { tgtM => - if !callsOut.contains(tgtM.fullName) then - dstGraph.addEdge(call, tgtM, EdgeTypes.CALL) - else - fallbackToStaticResolution(call, dstGraph) - } - case None => - fallbackToStaticResolution(call, dstGraph) - end match - end linkDynamicCall - - /** In the case where the method isn't an internal method and cannot be resolved by crawling - * TYPE_DECL nodes it can be resolved from the map of external methods. - */ - private def fallbackToStaticResolution(call: Call, dstGraph: DiffGraphBuilder): Unit = - methodMap.get(call.methodFullName) match - case Some(tgtM) => dstGraph.addEdge(call, tgtM, EdgeTypes.CALL) - case None => printLinkingError(call) - - private def nodesWithFullName(x: String): Iterable[NodeRef[? <: NodeDb]] = - cpg.graph.indexManager.lookup(PropertyNames.FULL_NAME, x).asScala - - private def methodFullNameToNode(x: String): Option[Method] = - nodesWithFullName(x).collectFirst { case x: Method => x } - - @inline - private def printLinkingError(call: Call): Unit = - logger.debug( - s"Unable to link dynamic CALL with METHOD_FULL_NAME ${call.methodFullName} and context: " + - s"${call.code} @ line ${call.lineNumber}" - ) + } + true + else + false + end if + end resolveCallInSuperClasses + + private def linkDynamicCall(call: Call, dstGraph: DiffGraphBuilder): Unit = + // This call linker requires a method full name entry + if call.methodFullName.equals("") || call.methodFullName.equals( + DynamicCallUnknownFullName + ) + then return + // Support for overriding + val resolved = resolveCallInSuperClasses(call) + var methodNameToUse = call.methodFullName + if call.methodFullName.startsWith("") && resolved then + methodNameToUse = call.argument.last.code + validM.get(methodNameToUse) match + case Some(tgts) => + val callsOut = call.callOut.fullName.toSetImmutable + val tgtMs = tgts + .flatMap(destMethod => + if cpg.graph.indexManager.isIndexed(PropertyNames.FULL_NAME) then + methodFullNameToNode(destMethod) + else + cpg.method.fullNameExact(destMethod).headOption + ) + .toSet + // Non-overridden methods linked as external stubs should be excluded if they are detected + val (externalMs, internalMs) = tgtMs.partition(_.isExternal) + (if externalMs.nonEmpty && internalMs.nonEmpty then internalMs else tgtMs) + .foreach { tgtM => + if !callsOut.contains(tgtM.fullName) then + dstGraph.addEdge(call, tgtM, EdgeTypes.CALL) + else + fallbackToStaticResolution(call, dstGraph) + } + case None => + fallbackToStaticResolution(call, dstGraph) + end match + end linkDynamicCall + + /** In the case where the method isn't an internal method and cannot be resolved by crawling + * TYPE_DECL nodes it can be resolved from the map of external methods. + */ + private def fallbackToStaticResolution(call: Call, dstGraph: DiffGraphBuilder): Unit = + methodMap.get(call.methodFullName) match + case Some(tgtM) => dstGraph.addEdge(call, tgtM, EdgeTypes.CALL) + case None => printLinkingError(call) + + private def nodesWithFullName(x: String): Iterable[NodeRef[? <: NodeDb]] = + cpg.graph.indexManager.lookup(PropertyNames.FULL_NAME, x).asScala + + private def methodFullNameToNode(x: String): Option[Method] = + nodesWithFullName(x).collectFirst { case x: Method => x } + + @inline + private def printLinkingError(call: Call): Unit = + logger.debug( + s"Unable to link dynamic CALL with METHOD_FULL_NAME ${call.methodFullName} and context: " + + s"${call.code} @ line ${call.lineNumber}" + ) end DynamicCallLinker object DynamicCallLinker: - private val logger: Logger = LoggerFactory.getLogger(classOf[DynamicCallLinker]) + private val logger: Logger = LoggerFactory.getLogger(classOf[DynamicCallLinker]) diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/callgraph/MethodRefLinker.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/callgraph/MethodRefLinker.scala index 64b32a42..00a26f81 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/callgraph/MethodRefLinker.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/callgraph/MethodRefLinker.scala @@ -10,15 +10,15 @@ import io.appthreat.x2cpg.utils.LinkingUtil */ class MethodRefLinker(cpg: Cpg) extends CpgPass(cpg) with LinkingUtil: - override def run(dstGraph: DiffGraphBuilder): Unit = - // Create REF edges from METHOD_REFs to METHOD - linkToSingle( - cpg, - srcLabels = List(NodeTypes.METHOD_REF), - dstNodeLabel = NodeTypes.METHOD, - edgeType = EdgeTypes.REF, - dstNodeMap = methodFullNameToNode(cpg, _), - dstFullNameKey = PropertyNames.METHOD_FULL_NAME, - dstGraph, - None - ) + override def run(dstGraph: DiffGraphBuilder): Unit = + // Create REF edges from METHOD_REFs to METHOD + linkToSingle( + cpg, + srcLabels = List(NodeTypes.METHOD_REF), + dstNodeLabel = NodeTypes.METHOD, + edgeType = EdgeTypes.REF, + dstNodeMap = methodFullNameToNode(cpg, _), + dstFullNameKey = PropertyNames.METHOD_FULL_NAME, + dstGraph, + None + ) diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/callgraph/NaiveCallLinker.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/callgraph/NaiveCallLinker.scala index 4fca6e32..5a30fcd8 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/callgraph/NaiveCallLinker.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/callgraph/NaiveCallLinker.scala @@ -12,15 +12,15 @@ import overflowdb.traversal.jIteratortoTraversal */ class NaiveCallLinker(cpg: Cpg) extends CpgPass(cpg): - override def run(dstGraph: DiffGraphBuilder): Unit = - val methodNameToNode = cpg.method.toList.groupBy(_.name) - def calls = cpg.call.filter(_.outE(EdgeTypes.CALL).isEmpty) - for - call <- calls - methods <- methodNameToNode.get(call.name) - method <- methods - do - dstGraph.addEdge(call, method, EdgeTypes.CALL) - // If we can only find one name with the exact match then we can semi-confidently set it as the full name - if methods.sizeIs == 1 then - dstGraph.setNodeProperty(call, PropertyNames.METHOD_FULL_NAME, method.fullName) + override def run(dstGraph: DiffGraphBuilder): Unit = + val methodNameToNode = cpg.method.toList.groupBy(_.name) + def calls = cpg.call.filter(_.outE(EdgeTypes.CALL).isEmpty) + for + call <- calls + methods <- methodNameToNode.get(call.name) + method <- methods + do + dstGraph.addEdge(call, method, EdgeTypes.CALL) + // If we can only find one name with the exact match then we can semi-confidently set it as the full name + if methods.sizeIs == 1 then + dstGraph.setNodeProperty(call, PropertyNames.METHOD_FULL_NAME, method.fullName) diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/callgraph/StaticCallLinker.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/callgraph/StaticCallLinker.scala index 36b93da0..c1e3e8b1 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/callgraph/StaticCallLinker.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/callgraph/StaticCallLinker.scala @@ -11,46 +11,46 @@ import scala.collection.mutable class StaticCallLinker(cpg: Cpg) extends CpgPass(cpg): - import StaticCallLinker.* - private val methodFullNameToNode = mutable.Map.empty[String, List[Method]] + import StaticCallLinker.* + private val methodFullNameToNode = mutable.Map.empty[String, List[Method]] - override def run(dstGraph: DiffGraphBuilder): Unit = + override def run(dstGraph: DiffGraphBuilder): Unit = - cpg.method.foreach { method => - methodFullNameToNode.updateWith(method.fullName) { - case Some(l) => Some(method :: l) - case None => Some(List(method)) - } + cpg.method.foreach { method => + methodFullNameToNode.updateWith(method.fullName) { + case Some(l) => Some(method :: l) + case None => Some(List(method)) } - - cpg.call.foreach { call => - try - linkCall(call, dstGraph) - catch - case exception: Exception => - throw new RuntimeException(exception) - } - - private def linkCall(call: Call, dstGraph: DiffGraphBuilder): Unit = - call.dispatchType match - case DispatchTypes.STATIC_DISPATCH | DispatchTypes.INLINED => - linkStaticCall(call, dstGraph) - case DispatchTypes.DYNAMIC_DISPATCH => - // Do nothing - case _ => logger.debug(s"Unknown dispatch type on dynamic CALL ${call.code}") - - private def linkStaticCall(call: Call, dstGraph: DiffGraphBuilder): Unit = - val resolvedMethodOption = methodFullNameToNode.get(call.methodFullName) - if resolvedMethodOption.isDefined then - resolvedMethodOption.get.foreach { dst => - dstGraph.addEdge(call, dst, EdgeTypes.CALL) - } - else - logger.debug( - s"Unable to link static CALL with METHOD_FULL_NAME ${call.methodFullName}, NAME ${call.name}, " + - s"SIGNATURE ${call.signature}, CODE ${call.code}" - ) + } + + cpg.call.foreach { call => + try + linkCall(call, dstGraph) + catch + case exception: Exception => + throw new RuntimeException(exception) + } + + private def linkCall(call: Call, dstGraph: DiffGraphBuilder): Unit = + call.dispatchType match + case DispatchTypes.STATIC_DISPATCH | DispatchTypes.INLINED => + linkStaticCall(call, dstGraph) + case DispatchTypes.DYNAMIC_DISPATCH => + // Do nothing + case _ => logger.debug(s"Unknown dispatch type on dynamic CALL ${call.code}") + + private def linkStaticCall(call: Call, dstGraph: DiffGraphBuilder): Unit = + val resolvedMethodOption = methodFullNameToNode.get(call.methodFullName) + if resolvedMethodOption.isDefined then + resolvedMethodOption.get.foreach { dst => + dstGraph.addEdge(call, dst, EdgeTypes.CALL) + } + else + logger.debug( + s"Unable to link static CALL with METHOD_FULL_NAME ${call.methodFullName}, NAME ${call.name}, " + + s"SIGNATURE ${call.signature}, CODE ${call.code}" + ) end StaticCallLinker object StaticCallLinker: - private val logger: Logger = LoggerFactory.getLogger(classOf[StaticCallLinker]) + private val logger: Logger = LoggerFactory.getLogger(classOf[StaticCallLinker]) diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/CfgCreationPass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/CfgCreationPass.scala index 9ba4af18..42e148d2 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/CfgCreationPass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/CfgCreationPass.scala @@ -17,9 +17,9 @@ import io.appthreat.x2cpg.passes.controlflow.cfgcreation.CfgCreator */ class CfgCreationPass(cpg: Cpg) extends ConcurrentWriterCpgPass[Method](cpg): - override def generateParts(): Array[Method] = cpg.method.toArray + override def generateParts(): Array[Method] = cpg.method.toArray - override def runOnPart(diffGraph: DiffGraphBuilder, method: Method): Unit = - val localDiff = new DiffGraphBuilder - new CfgCreator(method, localDiff).run() - diffGraph.absorb(localDiff) + override def runOnPart(diffGraph: DiffGraphBuilder, method: Method): Unit = + val localDiff = new DiffGraphBuilder + new CfgCreator(method, localDiff).run() + diffGraph.absorb(localDiff) diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgcreation/Cfg.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgcreation/Cfg.scala index 4f9065db..b5fd4290 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgcreation/Cfg.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgcreation/Cfg.scala @@ -42,152 +42,152 @@ case class Cfg( jumpsToLabel: List[(CfgNode, String)] = List() ): - import Cfg.* - - /** Create a new CFG in which `other` is appended to this CFG. All nodes of the fringe are - * connected to `other`'s entry node and the new fringe is `other`'s fringe. The diffgraphs, - * jumps, and labels are the sum of those present in `this` and `other`. - */ - def ++(other: Cfg): Cfg = - if other == Cfg.empty then - this - else if this == Cfg.empty then - other - else - this.copy( - fringe = other.fringe, - edges = this.edges ++ other.edges ++ - edgesFromFringeTo(this, other.entryNode), - jumpsToLabel = this.jumpsToLabel ++ other.jumpsToLabel, - labeledNodes = this.labeledNodes ++ other.labeledNodes, - breaks = this.breaks ++ other.breaks, - continues = this.continues ++ other.continues, - caseLabels = this.caseLabels ++ other.caseLabels - ) - - def withFringeEdgeType(cfgEdgeType: CfgEdgeType): Cfg = - this.copy(fringe = fringe.map { case (x, _) => (x, cfgEdgeType) }) - - /** Upon completing traversal of the abstract syntax tree, this method creates CFG edges between - * jumps like gotos, labeled breaks, labeled continues and respective labels. - */ - def withResolvedJumpToLabel(): Cfg = - val edges = jumpsToLabel.flatMap { - case (jumpToLabel, label) if label != "*" => - labeledNodes.get(label) match - case Some(labeledNode) => - // TODO set edge type of Always once the backend - // supports it - Some(CfgEdge(jumpToLabel, labeledNode, AlwaysEdge)) - case None => - logger.debug("Unable to wire jump statement. Missing label {}.", label) - None - case (jumpToLabel, _) => - // We come here for: https://gcc.gnu.org/onlinedocs/gcc/Labels-as-Values.html - // For such GOTOs we cannot statically determine the target label. As a quick - // hack we simply put edges to all labels found. This might be an over-taint. - labeledNodes.flatMap { case (_, labeledNode) => - Some(CfgEdge(jumpToLabel, labeledNode, AlwaysEdge)) - } - } - this.copy(edges = this.edges ++ edges) - end withResolvedJumpToLabel + import Cfg.* + + /** Create a new CFG in which `other` is appended to this CFG. All nodes of the fringe are + * connected to `other`'s entry node and the new fringe is `other`'s fringe. The diffgraphs, + * jumps, and labels are the sum of those present in `this` and `other`. + */ + def ++(other: Cfg): Cfg = + if other == Cfg.empty then + this + else if this == Cfg.empty then + other + else + this.copy( + fringe = other.fringe, + edges = this.edges ++ other.edges ++ + edgesFromFringeTo(this, other.entryNode), + jumpsToLabel = this.jumpsToLabel ++ other.jumpsToLabel, + labeledNodes = this.labeledNodes ++ other.labeledNodes, + breaks = this.breaks ++ other.breaks, + continues = this.continues ++ other.continues, + caseLabels = this.caseLabels ++ other.caseLabels + ) + + def withFringeEdgeType(cfgEdgeType: CfgEdgeType): Cfg = + this.copy(fringe = fringe.map { case (x, _) => (x, cfgEdgeType) }) + + /** Upon completing traversal of the abstract syntax tree, this method creates CFG edges between + * jumps like gotos, labeled breaks, labeled continues and respective labels. + */ + def withResolvedJumpToLabel(): Cfg = + val edges = jumpsToLabel.flatMap { + case (jumpToLabel, label) if label != "*" => + labeledNodes.get(label) match + case Some(labeledNode) => + // TODO set edge type of Always once the backend + // supports it + Some(CfgEdge(jumpToLabel, labeledNode, AlwaysEdge)) + case None => + logger.debug("Unable to wire jump statement. Missing label {}.", label) + None + case (jumpToLabel, _) => + // We come here for: https://gcc.gnu.org/onlinedocs/gcc/Labels-as-Values.html + // For such GOTOs we cannot statically determine the target label. As a quick + // hack we simply put edges to all labels found. This might be an over-taint. + labeledNodes.flatMap { case (_, labeledNode) => + Some(CfgEdge(jumpToLabel, labeledNode, AlwaysEdge)) + } + } + this.copy(edges = this.edges ++ edges) + end withResolvedJumpToLabel end Cfg case class CfgEdge(src: CfgNode, dst: CfgNode, edgeType: CfgEdgeType) object Cfg: - private val logger = LoggerFactory.getLogger(getClass) - - def from(cfgs: Cfg*): Cfg = - Cfg( - jumpsToLabel = cfgs.map(_.jumpsToLabel).reduceOption((x, y) => x ++ y).getOrElse(List()), - breaks = cfgs.map(_.breaks).reduceOption((x, y) => x ++ y).getOrElse(List()), - continues = cfgs.map(_.continues).reduceOption((x, y) => x ++ y).getOrElse(List()), - caseLabels = cfgs.map(_.caseLabels).reduceOption((x, y) => x ++ y).getOrElse(List()), - labeledNodes = cfgs.map(_.labeledNodes).reduceOption((x, y) => x ++ y).getOrElse(Map()) - ) - - /** The safe "null" Cfg. - */ - val empty: Cfg = new Cfg() - - trait CfgEdgeType - object TrueEdge extends CfgEdgeType: - override def toString: String = "TrueEdge" - object FalseEdge extends CfgEdgeType: - override def toString: String = "FalseEdge" - object AlwaysEdge extends CfgEdgeType: - override def toString: String = "AlwaysEdge" - object CaseEdge extends CfgEdgeType: - override def toString: String = "CaseEdge" - - /** Create edges from all nodes of cfg's fringe to `node`. - */ - def edgesFromFringeTo(cfg: Cfg, node: Option[CfgNode]): List[CfgEdge] = - edgesFromFringeTo(cfg.fringe, node) - - /** Create edges from all nodes of cfg's fringe to `node`, ignoring fringe edge types and using - * `cfgEdgeType` instead. - */ - def edgesFromFringeTo( - cfg: Cfg, - node: Option[CfgNode], - cfgEdgeType: CfgEdgeType - ): List[CfgEdge] = - edges(cfg.fringe.map(_._1), node, cfgEdgeType) - - /** Create edges from a list (node, cfgEdgeType) pairs to `node` - */ - def edgesFromFringeTo( - fringeElems: List[(CfgNode, CfgEdgeType)], - node: Option[CfgNode] - ): List[CfgEdge] = - fringeElems.flatMap { case (sourceNode, cfgEdgeType) => - node.map { dstNode => - CfgEdge(sourceNode, dstNode, cfgEdgeType) - } - } - - /** Create edges of given type from a list of source nodes to a destination node - */ - def edges( - sources: List[CfgNode], - dstNode: Option[CfgNode], - cfgEdgeType: CfgEdgeType = AlwaysEdge - ): List[CfgEdge] = - edgesToMultiple(sources, dstNode.toList, cfgEdgeType) - - def singleEdge( - source: CfgNode, - destination: CfgNode, - cfgEdgeType: CfgEdgeType = AlwaysEdge - ): List[CfgEdge] = - edgesToMultiple(List(source), List(destination), cfgEdgeType) - - /** Create edges of given type from all nodes in `sources` to `node`. - */ - def edgesToMultiple( - sources: List[CfgNode], - destinations: List[CfgNode], - cfgEdgeType: CfgEdgeType = AlwaysEdge - ): List[CfgEdge] = - sources.flatMap { l => - destinations.map { n => - CfgEdge(l, n, cfgEdgeType) - } - } - - def takeCurrentLevel(nodesWithLevel: List[(CfgNode, Int)]): List[CfgNode] = - nodesWithLevel.collect { - case (node, level) if level == 1 => - node - } - - def reduceAndFilterLevel(nodesWithLevel: List[(CfgNode, Int)]): List[(CfgNode, Int)] = - nodesWithLevel.collect { - case (node, level) if level != 1 => - (node, level - 1) - } + private val logger = LoggerFactory.getLogger(getClass) + + def from(cfgs: Cfg*): Cfg = + Cfg( + jumpsToLabel = cfgs.map(_.jumpsToLabel).reduceOption((x, y) => x ++ y).getOrElse(List()), + breaks = cfgs.map(_.breaks).reduceOption((x, y) => x ++ y).getOrElse(List()), + continues = cfgs.map(_.continues).reduceOption((x, y) => x ++ y).getOrElse(List()), + caseLabels = cfgs.map(_.caseLabels).reduceOption((x, y) => x ++ y).getOrElse(List()), + labeledNodes = cfgs.map(_.labeledNodes).reduceOption((x, y) => x ++ y).getOrElse(Map()) + ) + + /** The safe "null" Cfg. + */ + val empty: Cfg = new Cfg() + + trait CfgEdgeType + object TrueEdge extends CfgEdgeType: + override def toString: String = "TrueEdge" + object FalseEdge extends CfgEdgeType: + override def toString: String = "FalseEdge" + object AlwaysEdge extends CfgEdgeType: + override def toString: String = "AlwaysEdge" + object CaseEdge extends CfgEdgeType: + override def toString: String = "CaseEdge" + + /** Create edges from all nodes of cfg's fringe to `node`. + */ + def edgesFromFringeTo(cfg: Cfg, node: Option[CfgNode]): List[CfgEdge] = + edgesFromFringeTo(cfg.fringe, node) + + /** Create edges from all nodes of cfg's fringe to `node`, ignoring fringe edge types and using + * `cfgEdgeType` instead. + */ + def edgesFromFringeTo( + cfg: Cfg, + node: Option[CfgNode], + cfgEdgeType: CfgEdgeType + ): List[CfgEdge] = + edges(cfg.fringe.map(_._1), node, cfgEdgeType) + + /** Create edges from a list (node, cfgEdgeType) pairs to `node` + */ + def edgesFromFringeTo( + fringeElems: List[(CfgNode, CfgEdgeType)], + node: Option[CfgNode] + ): List[CfgEdge] = + fringeElems.flatMap { case (sourceNode, cfgEdgeType) => + node.map { dstNode => + CfgEdge(sourceNode, dstNode, cfgEdgeType) + } + } + + /** Create edges of given type from a list of source nodes to a destination node + */ + def edges( + sources: List[CfgNode], + dstNode: Option[CfgNode], + cfgEdgeType: CfgEdgeType = AlwaysEdge + ): List[CfgEdge] = + edgesToMultiple(sources, dstNode.toList, cfgEdgeType) + + def singleEdge( + source: CfgNode, + destination: CfgNode, + cfgEdgeType: CfgEdgeType = AlwaysEdge + ): List[CfgEdge] = + edgesToMultiple(List(source), List(destination), cfgEdgeType) + + /** Create edges of given type from all nodes in `sources` to `node`. + */ + def edgesToMultiple( + sources: List[CfgNode], + destinations: List[CfgNode], + cfgEdgeType: CfgEdgeType = AlwaysEdge + ): List[CfgEdge] = + sources.flatMap { l => + destinations.map { n => + CfgEdge(l, n, cfgEdgeType) + } + } + + def takeCurrentLevel(nodesWithLevel: List[(CfgNode, Int)]): List[CfgNode] = + nodesWithLevel.collect { + case (node, level) if level == 1 => + node + } + + def reduceAndFilterLevel(nodesWithLevel: List[(CfgNode, Int)]): List[(CfgNode, Int)] = + nodesWithLevel.collect { + case (node, level) if level != 1 => + (node, level - 1) + } end Cfg diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgcreation/CfgCreator.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgcreation/CfgCreator.scala index c28a79e9..996fa543 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgcreation/CfgCreator.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgcreation/CfgCreator.scala @@ -48,575 +48,574 @@ import overflowdb.BatchedUpdate.DiffGraphBuilder */ class CfgCreator(entryNode: Method, diffGraph: DiffGraphBuilder): - import io.appthreat.x2cpg.passes.controlflow.cfgcreation.Cfg.* - import io.appthreat.x2cpg.passes.controlflow.cfgcreation.CfgCreator.* - - /** Control flow graph definitions often feature a designated entry and exit node for each - * method. While these nodes are no-ops from a computational point of view, they are useful to - * guarantee that a method has exactly one entry and one exit. - * - * For the CPG-based control flow graph, we do not need to introduce fake entry and exit node. - * Instead, we can use the METHOD and METHOD_RETURN nodes as entry and exit nodes respectively. - * Note that METHOD_RETURN nodes are the nodes representing formal return parameters, of which - * there exists exactly one per method. - */ - private val exitNode: MethodReturn = entryNode.methodReturn - - /** We return the CFG as a sequence of Diff Graphs that is calculated by first obtaining the CFG - * for the method and then resolving gotos. - */ - def run(): Unit = - cfgForMethod(entryNode).withResolvedJumpToLabel().edges.foreach { edge => - // TODO: we are ignoring edge.edgeType because the - // CFG spec doesn't define an edge type at the moment - diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.CFG) - } - - /** Conversion of a method to a CFG, showing the decomposition of the control flow graph - * generation problem into that of translating sub trees according to the node type. In the - * particular case of a method, the CFG is obtained by creating a CFG containing the single - * method node and a fringe containing the node and an outgoing AlwaysEdge, to the CFG obtained - * by translating child CFGs one by one and appending them. - */ - private def cfgForMethod(node: Method): Cfg = - cfgForSingleNode(node) ++ cfgForChildren(node) - - /** For any single AST node, we can construct a CFG containing that single node by setting it as - * the entry node and placing it in the fringe. - */ - private def cfgForSingleNode(node: CfgNode): Cfg = - Cfg(entryNode = Option(node), fringe = List((node, AlwaysEdge))) - - /** The CFG for all children is obtained by translating child ASTs one by one from left to right - * and appending them. - */ - private def cfgForChildren(node: AstNode): Cfg = - node.astChildren.l.map(cfgFor).reduceOption((x, y) => x ++ y).getOrElse(Cfg.empty) - - /** Returns true if this node is a child to some `try` control structure, false if otherwise. - */ - private def withinATryBlock(x: AstNode): Boolean = - x.inAst.isControlStructure.exists(_.controlStructureType == ControlStructureTypes.TRY) - - /** This method dispatches AST nodes by type and calls corresponding conversion methods. - */ - protected def cfgFor(node: AstNode): Cfg = - node match - case _: Method | _: MethodParameterIn | _: Modifier | _: Local | _: TypeDecl | _: Member => - Cfg.empty - case _: MethodRef | _: TypeRef | _: MethodReturn => - cfgForSingleNode(node.asInstanceOf[CfgNode]) - case controlStructure: ControlStructure => - cfgForControlStructure(controlStructure) - case jumpTarget: JumpTarget => - cfgForJumpTarget(jumpTarget) - case ret: Return if withinATryBlock(ret) => - cfgForReturn(ret, inheritFringe = true) - case ret: Return => - cfgForReturn(ret) - case call: Call if call.name == Operators.logicalAnd => - cfgForAndExpression(call) - case call: Call if call.name == Operators.logicalOr => - cfgForOrExpression(call) - case call: Call if call.name == Operators.conditional => - cfgForConditionalExpression(call) - case call: Call if call.dispatchType == DispatchTypes.INLINED => - cfgForInlinedCall(call) - case block: Block if blockMatches(block) => - cfgForChildren(block) - case _: Block => - cfgForChildren(node) ++ cfgForSingleNode(node.asInstanceOf[CfgNode]) - case _: Call | _: FieldIdentifier | _: Identifier | _: Literal | _: Block | _: Unknown => - cfgForChildren(node) ++ cfgForSingleNode(node.asInstanceOf[CfgNode]) - case _ => - cfgForChildren(node) - - private def isLogicalOperator(node: AstNode): Boolean = node match - case call: Call => - call.name == Operators.conditional || call.name == Operators.logicalOr || call.name == Operators.logicalAnd - case _ => false - - private def isInlinedCall(node: AstNode): Boolean = node match - case call: Call => call.dispatchType == DispatchTypes.INLINED - case _ => false - - /** Only include block nodes that do not describe the entire method body or the bodies of - * control structures or inlined calls or logical operators. - */ - private def blockMatches(block: Block): Boolean = - if block._astIn.hasNext then - val parentNode = block.astParent - parentNode.isMethod || parentNode.isControlStructure || isLogicalOperator( - parentNode - ) || isInlinedCall(parentNode) - else false - - /** A second layer of dispatching for control structures. This could as well be part of `cfgFor` - * and has only been placed into a separate function to increase readability. - */ - protected def cfgForControlStructure(node: ControlStructure): Cfg = - node.controlStructureType match - case ControlStructureTypes.BREAK => - cfgForBreakStatement(node) - case ControlStructureTypes.CONTINUE => - cfgForContinueStatement(node) - case ControlStructureTypes.WHILE => - cfgForWhileStatement(node) - case ControlStructureTypes.DO => - cfgForDoStatement(node) - case ControlStructureTypes.FOR => - cfgForForStatement(node) - case ControlStructureTypes.GOTO => - cfgForGotoStatement(node) - case ControlStructureTypes.IF => - cfgForIfStatement(node) - case ControlStructureTypes.ELSE => - cfgForChildren(node) - case ControlStructureTypes.SWITCH => - cfgForSwitchStatement(node) - case ControlStructureTypes.TRY => - cfgForTryStatement(node) - case ControlStructureTypes.MATCH => - cfgForMatchExpression(node) - case _ => - Cfg.empty - - /** The CFG for a break/continue statements contains only the break/continue statement as a - * single entry node. The fringe is empty, that is, appending another CFG to the break - * statement will not result in the creation of an edge from the break statement to the entry - * point of the other CFG. Labeled breaks are treated like gotos and are added to - * "jumpsToLabel". - */ - protected def cfgForBreakStatement(node: ControlStructure): Cfg = - node.astChildren.find(_.order == 1) match - case Some(jumpLabel: JumpLabel) => - val labelName = jumpLabel.name - Cfg(entryNode = Option(node), jumpsToLabel = List((node, labelName))) - case Some(literal: Literal) => - // In case we find a literal, it is assumed to be an integer literal which - // indicates how many loop/switch levels the break shall apply to. - val numberOfLevels = Integer.valueOf(literal.code) - Cfg(entryNode = Option(node), breaks = List((node, numberOfLevels))) - case Some(_) => - throw new NotImplementedError( - "Only jump labels and integer literals are currently supported for break statements." - ) - case None => - Cfg(entryNode = Option(node), breaks = List((node, 1))) - - protected def cfgForContinueStatement(node: ControlStructure): Cfg = - node.astChildren.find(_.order == 1) match - case Some(jumpLabel: JumpLabel) => - val labelName = jumpLabel.name - Cfg(entryNode = Option(node), jumpsToLabel = List((node, labelName))) - case Some(literal: Literal) => - // In case we find a literal, it is assumed to be an integer literal which - // indicates how many loop levels the continue shall apply to. - val numberOfLevels = Integer.valueOf(literal.code) - Cfg(entryNode = Option(node), continues = List((node, numberOfLevels))) - case Some(_) => - throw new NotImplementedError( - "Only jump labels and integer literals are currently supported for continue statements." - ) - case None => - Cfg(entryNode = Option(node), continues = List((node, 1))) - - /** Jump targets ("labels") are included in the CFG. As these should be connected to the next - * appended CFG, we specify that the label node is both the entry node and the only node in the - * fringe. This is achieved by calling `cfgForSingleNode` on the label node. Just like for - * breaks and continues, we record labels. We store case/default labels separately from other - * labels, but that is not a relevant implementation detail. - */ - protected def cfgForJumpTarget(n: JumpTarget): Cfg = - val labelName = n.name - val cfg = cfgForSingleNode(n) - if labelName.startsWith("case") || labelName.startsWith("default") then - cfg.copy(caseLabels = List(n)) - else - cfg.copy(labeledNodes = Map(labelName -> n)) - - /** A CFG for a goto statement is one containing the goto node as an entry node and an empty - * fringe. Moreover, we store the goto for dispatching with `withResolvedJumpToLabel` once the - * CFG for the entire method has been calculated. - */ - protected def cfgForGotoStatement(node: ControlStructure): Cfg = - node.astChildren.find(_.order == 1) match - case Some(jumpLabel) => - val labelName = jumpLabel.asInstanceOf[JumpLabel].name - Cfg(entryNode = Option(node), jumpsToLabel = List((node, labelName))) - case None => - // Support for old format where the label name is parsed from the code field. - val target = node.code.split(" ").lastOption.map(x => x.slice(0, x.length - 1)) - target.map(t => - Cfg(entryNode = Some(node), jumpsToLabel = List((node, t))) - ).getOrElse(Cfg.empty) - - /** Return statements may contain expressions as return values, and therefore, the CFG for a - * return statement consists of the CFG for calculation of that expression, appended to a CFG - * containing only the return node, connected with a single edge to the method exit node. The - * fringe is empty. - * - * @param inheritFringe - * indicates if the resulting Cfg object must contain the fringe value of the return value's - * children. - */ - protected def cfgForReturn(actualRet: Return, inheritFringe: Boolean = false): Cfg = - val childrenCfg = cfgForChildren(actualRet) - childrenCfg ++ - Cfg( - entryNode = Option(actualRet), - edges = singleEdge(actualRet, exitNode), - if inheritFringe then childrenCfg.fringe else List() - ) - - /** The right hand side of a logical AND expression is only evaluated if the left hand side is - * true as the entire expression can only be true if both expressions are true. This is encoded - * in the corresponding control flow graph by creating control flow graphs for the left and - * right hand expressions and appending the two, where the fringe edge type of the left CFG is - * `TrueEdge`. - */ - protected def cfgForAndExpression(call: Call): Cfg = - val leftCfg = cfgFor(call.argument(1)) - val rightCfg = cfgFor(call.argument(2)) - val diffGraphs = edgesFromFringeTo( - leftCfg, - rightCfg.entryNode, - TrueEdge - ) ++ leftCfg.edges ++ rightCfg.edges - Cfg - .from(leftCfg, rightCfg) - .copy( - entryNode = leftCfg.entryNode, - edges = diffGraphs, - fringe = leftCfg.fringe ++ rightCfg.fringe - ) ++ cfgForSingleNode(call) - - /** Same construction recipe as for the AND expression, just that the fringe edge type of the - * left CFG is `FalseEdge`. - */ - protected def cfgForOrExpression(call: Call): Cfg = - val leftCfg = cfgFor(call.argument(1)) - val rightCfg = cfgFor(call.argument(2)) - val diffGraphs = edgesFromFringeTo( - leftCfg, - rightCfg.entryNode, - FalseEdge - ) ++ leftCfg.edges ++ rightCfg.edges - Cfg - .from(leftCfg, rightCfg) - .copy( - entryNode = leftCfg.entryNode, - edges = diffGraphs, - fringe = leftCfg.fringe ++ rightCfg.fringe - ) ++ cfgForSingleNode(call) - - /** A conditional expression is of the form `condition ? trueExpr ; falseExpr` where both - * `trueExpr` and `falseExpr` are optional. We create the corresponding CFGs by creating CFGs - * for the three expressions and adding edges between them. The new entry node is the condition - * entry node. - */ - protected def cfgForConditionalExpression(call: Call): Cfg = - val conditionCfg = cfgFor(call.argument(1)) - val trueCfg = call.argumentOption(2).map(cfgFor).getOrElse(Cfg.empty) - val falseCfg = call.argumentOption(3).map(cfgFor).getOrElse(Cfg.empty) - val diffGraphs = edgesFromFringeTo(conditionCfg, trueCfg.entryNode, TrueEdge) ++ - edgesFromFringeTo(conditionCfg, falseCfg.entryNode, FalseEdge) - - val trueFridge = if trueCfg.entryNode.isDefined then - trueCfg.fringe - else - conditionCfg.fringe.withEdgeType(TrueEdge) - val falseFridge = if falseCfg.entryNode.isDefined then - falseCfg.fringe - else - conditionCfg.fringe.withEdgeType(FalseEdge) - - Cfg - .from(conditionCfg, trueCfg, falseCfg) - .copy( - entryNode = conditionCfg.entryNode, - edges = conditionCfg.edges ++ trueCfg.edges ++ falseCfg.edges ++ diffGraphs, - fringe = trueFridge ++ falseFridge - ) ++ cfgForSingleNode(call) - end cfgForConditionalExpression - - /** For macros, the AST contains a CALL node, along with child sub trees for all arguments, and - * a final sub tree that contains the inlined code. The corresponding CFG consists of the CFG - * for the call, an edge to the exit and an edge to the CFG of the inlined code. We choose this - * representation because it allows both queries that use the macro reference as well as - * queries that reference the inline code to be chosen as sources/sinks in data flow queries. - */ - def cfgForInlinedCall(call: Call): Cfg = - val cfgForMacroCall = call.argument.l - .map(cfgFor) - .reduceOption((x, y) => x ++ y) - .getOrElse(Cfg.empty) ++ cfgForSingleNode(call) - val cfgForExpansion = call.astChildren.lastOption.map(cfgFor).getOrElse(Cfg.empty) - val cfg = Cfg - .from(cfgForMacroCall, cfgForExpansion) - .copy( - entryNode = cfgForMacroCall.entryNode, - edges = - cfgForMacroCall.edges ++ cfgForExpansion.edges ++ cfgForExpansion.entryNode.toList - .flatMap(x => singleEdge(call, x)), - fringe = cfgForMacroCall.fringe ++ cfgForExpansion.fringe - ) - cfg - - /** A for statement is of the form `for(initExpr; condition; loopExpr) body` and all four - * components may be empty. The sequence (condition - body - loopExpr) form the inner part of - * the loop and we calculate the corresponding CFG `innerCfg` so that it is no longer relevant - * which of these three actually exist and we still have an entry node for the loop and a - * fringe. - */ - protected def cfgForForStatement(node: ControlStructure): Cfg = - val children = node.astChildren.l - val nLocals = children.count(_.isLocal) - val initExprCfg = children.find(_.order == nLocals + 1).map(cfgFor).getOrElse(Cfg.empty) - val conditionCfg = children.find(_.order == nLocals + 2).map(cfgFor).getOrElse(Cfg.empty) - val loopExprCfg = children.find(_.order == nLocals + 3).map(cfgFor).getOrElse(Cfg.empty) - val bodyCfg = children.find(_.order == nLocals + 4).map(cfgFor).getOrElse(Cfg.empty) - - val innerCfg = conditionCfg ++ bodyCfg ++ loopExprCfg - val entryNode = (initExprCfg ++ innerCfg).entryNode - - val newEdges = edgesFromFringeTo(initExprCfg, innerCfg.entryNode) ++ - edgesFromFringeTo(innerCfg, innerCfg.entryNode) ++ - edgesFromFringeTo(conditionCfg, bodyCfg.entryNode, TrueEdge) ++ { - if loopExprCfg.entryNode.isDefined then - edges(takeCurrentLevel(bodyCfg.continues), loopExprCfg.entryNode) - else - edges(takeCurrentLevel(bodyCfg.continues), innerCfg.entryNode) - } - - Cfg - .from(initExprCfg, conditionCfg, loopExprCfg, bodyCfg) - .copy( - entryNode = entryNode, - edges = newEdges ++ initExprCfg.edges ++ innerCfg.edges, - fringe = conditionCfg.fringe.withEdgeType(FalseEdge) ++ takeCurrentLevel( - bodyCfg.breaks - ).map((_, AlwaysEdge)), - breaks = reduceAndFilterLevel(bodyCfg.breaks), - continues = reduceAndFilterLevel(bodyCfg.continues) - ) - end cfgForForStatement - - /** A Do-Statement is of the form `do body while(condition)` where body may be empty. We again - * first calculate the inner CFG as bodyCfg ++ conditionCfg and then connect edges according to - * the semantics of do-while. - */ - protected def cfgForDoStatement(node: ControlStructure): Cfg = - val bodyCfg = node.astChildren.where(_.order(1)).headOption.map(cfgFor).getOrElse(Cfg.empty) - val conditionCfg = node.condition.headOption.map(cfgFor).getOrElse(Cfg.empty) - val innerCfg = bodyCfg ++ conditionCfg - - val diffGraphs = - edges(takeCurrentLevel(bodyCfg.continues), conditionCfg.entryNode) ++ - edgesFromFringeTo(bodyCfg, conditionCfg.entryNode) ++ - edgesFromFringeTo(conditionCfg, innerCfg.entryNode, TrueEdge) - - Cfg - .from(bodyCfg, conditionCfg, innerCfg) - .copy( - entryNode = if bodyCfg != Cfg.empty then bodyCfg.entryNode - else conditionCfg.entryNode, - edges = diffGraphs ++ bodyCfg.edges ++ conditionCfg.edges, - fringe = conditionCfg.fringe.withEdgeType(FalseEdge) ++ takeCurrentLevel( - bodyCfg.breaks - ).map((_, AlwaysEdge)), - breaks = reduceAndFilterLevel(bodyCfg.breaks), - continues = reduceAndFilterLevel(bodyCfg.continues) + import io.appthreat.x2cpg.passes.controlflow.cfgcreation.Cfg.* + import io.appthreat.x2cpg.passes.controlflow.cfgcreation.CfgCreator.* + + /** Control flow graph definitions often feature a designated entry and exit node for each method. + * While these nodes are no-ops from a computational point of view, they are useful to guarantee + * that a method has exactly one entry and one exit. + * + * For the CPG-based control flow graph, we do not need to introduce fake entry and exit node. + * Instead, we can use the METHOD and METHOD_RETURN nodes as entry and exit nodes respectively. + * Note that METHOD_RETURN nodes are the nodes representing formal return parameters, of which + * there exists exactly one per method. + */ + private val exitNode: MethodReturn = entryNode.methodReturn + + /** We return the CFG as a sequence of Diff Graphs that is calculated by first obtaining the CFG + * for the method and then resolving gotos. + */ + def run(): Unit = + cfgForMethod(entryNode).withResolvedJumpToLabel().edges.foreach { edge => + // TODO: we are ignoring edge.edgeType because the + // CFG spec doesn't define an edge type at the moment + diffGraph.addEdge(edge.src, edge.dst, EdgeTypes.CFG) + } + + /** Conversion of a method to a CFG, showing the decomposition of the control flow graph + * generation problem into that of translating sub trees according to the node type. In the + * particular case of a method, the CFG is obtained by creating a CFG containing the single + * method node and a fringe containing the node and an outgoing AlwaysEdge, to the CFG obtained + * by translating child CFGs one by one and appending them. + */ + private def cfgForMethod(node: Method): Cfg = + cfgForSingleNode(node) ++ cfgForChildren(node) + + /** For any single AST node, we can construct a CFG containing that single node by setting it as + * the entry node and placing it in the fringe. + */ + private def cfgForSingleNode(node: CfgNode): Cfg = + Cfg(entryNode = Option(node), fringe = List((node, AlwaysEdge))) + + /** The CFG for all children is obtained by translating child ASTs one by one from left to right + * and appending them. + */ + private def cfgForChildren(node: AstNode): Cfg = + node.astChildren.l.map(cfgFor).reduceOption((x, y) => x ++ y).getOrElse(Cfg.empty) + + /** Returns true if this node is a child to some `try` control structure, false if otherwise. + */ + private def withinATryBlock(x: AstNode): Boolean = + x.inAst.isControlStructure.exists(_.controlStructureType == ControlStructureTypes.TRY) + + /** This method dispatches AST nodes by type and calls corresponding conversion methods. + */ + protected def cfgFor(node: AstNode): Cfg = + node match + case _: Method | _: MethodParameterIn | _: Modifier | _: Local | _: TypeDecl | _: Member => + Cfg.empty + case _: MethodRef | _: TypeRef | _: MethodReturn => + cfgForSingleNode(node.asInstanceOf[CfgNode]) + case controlStructure: ControlStructure => + cfgForControlStructure(controlStructure) + case jumpTarget: JumpTarget => + cfgForJumpTarget(jumpTarget) + case ret: Return if withinATryBlock(ret) => + cfgForReturn(ret, inheritFringe = true) + case ret: Return => + cfgForReturn(ret) + case call: Call if call.name == Operators.logicalAnd => + cfgForAndExpression(call) + case call: Call if call.name == Operators.logicalOr => + cfgForOrExpression(call) + case call: Call if call.name == Operators.conditional => + cfgForConditionalExpression(call) + case call: Call if call.dispatchType == DispatchTypes.INLINED => + cfgForInlinedCall(call) + case block: Block if blockMatches(block) => + cfgForChildren(block) + case _: Block => + cfgForChildren(node) ++ cfgForSingleNode(node.asInstanceOf[CfgNode]) + case _: Call | _: FieldIdentifier | _: Identifier | _: Literal | _: Block | _: Unknown => + cfgForChildren(node) ++ cfgForSingleNode(node.asInstanceOf[CfgNode]) + case _ => + cfgForChildren(node) + + private def isLogicalOperator(node: AstNode): Boolean = node match + case call: Call => + call.name == Operators.conditional || call.name == Operators.logicalOr || call + .name == Operators.logicalAnd + case _ => false + + private def isInlinedCall(node: AstNode): Boolean = node match + case call: Call => call.dispatchType == DispatchTypes.INLINED + case _ => false + + /** Only include block nodes that do not describe the entire method body or the bodies of control + * structures or inlined calls or logical operators. + */ + private def blockMatches(block: Block): Boolean = + if block._astIn.hasNext then + val parentNode = block.astParent + parentNode.isMethod || parentNode.isControlStructure || isLogicalOperator( + parentNode + ) || isInlinedCall(parentNode) + else false + + /** A second layer of dispatching for control structures. This could as well be part of `cfgFor` + * and has only been placed into a separate function to increase readability. + */ + protected def cfgForControlStructure(node: ControlStructure): Cfg = + node.controlStructureType match + case ControlStructureTypes.BREAK => + cfgForBreakStatement(node) + case ControlStructureTypes.CONTINUE => + cfgForContinueStatement(node) + case ControlStructureTypes.WHILE => + cfgForWhileStatement(node) + case ControlStructureTypes.DO => + cfgForDoStatement(node) + case ControlStructureTypes.FOR => + cfgForForStatement(node) + case ControlStructureTypes.GOTO => + cfgForGotoStatement(node) + case ControlStructureTypes.IF => + cfgForIfStatement(node) + case ControlStructureTypes.ELSE => + cfgForChildren(node) + case ControlStructureTypes.SWITCH => + cfgForSwitchStatement(node) + case ControlStructureTypes.TRY => + cfgForTryStatement(node) + case ControlStructureTypes.MATCH => + cfgForMatchExpression(node) + case _ => + Cfg.empty + + /** The CFG for a break/continue statements contains only the break/continue statement as a single + * entry node. The fringe is empty, that is, appending another CFG to the break statement will + * not result in the creation of an edge from the break statement to the entry point of the other + * CFG. Labeled breaks are treated like gotos and are added to "jumpsToLabel". + */ + protected def cfgForBreakStatement(node: ControlStructure): Cfg = + node.astChildren.find(_.order == 1) match + case Some(jumpLabel: JumpLabel) => + val labelName = jumpLabel.name + Cfg(entryNode = Option(node), jumpsToLabel = List((node, labelName))) + case Some(literal: Literal) => + // In case we find a literal, it is assumed to be an integer literal which + // indicates how many loop/switch levels the break shall apply to. + val numberOfLevels = Integer.valueOf(literal.code) + Cfg(entryNode = Option(node), breaks = List((node, numberOfLevels))) + case Some(_) => + throw new NotImplementedError( + "Only jump labels and integer literals are currently supported for break statements." ) - end cfgForDoStatement - - /** CFG creation for while statements of the form `while(condition) body1 else body2` where - * body1 and the else block are optional. - */ - protected def cfgForWhileStatement(node: ControlStructure): Cfg = - val conditionCfg = node.condition.headOption.map(cfgFor).getOrElse(Cfg.empty) - val trueCfg = node.whenTrue.headOption.map(cfgFor).getOrElse(Cfg.empty) - val falseCfg = node.whenFalse.headOption.map(cfgFor).getOrElse(Cfg.empty) - - val diffGraphs = edgesFromFringeTo(conditionCfg, trueCfg.entryNode) ++ - edgesFromFringeTo(trueCfg, falseCfg.entryNode) ++ - edgesFromFringeTo(trueCfg, conditionCfg.entryNode) ++ - edges(takeCurrentLevel(trueCfg.continues), conditionCfg.entryNode) - - Cfg - .from(conditionCfg, trueCfg, falseCfg) - .copy( - entryNode = conditionCfg.entryNode, - edges = diffGraphs ++ conditionCfg.edges ++ trueCfg.edges ++ falseCfg.edges, - fringe = - conditionCfg.fringe.withEdgeType(FalseEdge) ++ takeCurrentLevel(trueCfg.breaks) - .map((_, AlwaysEdge)) ++ falseCfg.fringe, - breaks = reduceAndFilterLevel(trueCfg.breaks), - continues = reduceAndFilterLevel(trueCfg.continues) + case None => + Cfg(entryNode = Option(node), breaks = List((node, 1))) + + protected def cfgForContinueStatement(node: ControlStructure): Cfg = + node.astChildren.find(_.order == 1) match + case Some(jumpLabel: JumpLabel) => + val labelName = jumpLabel.name + Cfg(entryNode = Option(node), jumpsToLabel = List((node, labelName))) + case Some(literal: Literal) => + // In case we find a literal, it is assumed to be an integer literal which + // indicates how many loop levels the continue shall apply to. + val numberOfLevels = Integer.valueOf(literal.code) + Cfg(entryNode = Option(node), continues = List((node, numberOfLevels))) + case Some(_) => + throw new NotImplementedError( + "Only jump labels and integer literals are currently supported for continue statements." ) - end cfgForWhileStatement - - /** CFG creation for switch statements of the form `switch { case condition: ... }`. - */ - protected def cfgForSwitchStatement(node: ControlStructure): Cfg = - val conditionCfg = node.condition.headOption.map(cfgFor).getOrElse(Cfg.empty) - val bodyCfg = node.whenTrue.headOption.map(cfgFor).getOrElse(Cfg.empty) - - cfgForSwitchLike(conditionCfg, bodyCfg :: Nil) - - /** CFG creation for if statements of the form `if(condition) body`, optionally followed by - * `else body2`. - */ - protected def cfgForIfStatement(node: ControlStructure): Cfg = - val conditionCfg = node.condition.headOption.map(cfgFor).getOrElse(Cfg.empty) - val trueCfg = node.whenTrue.headOption.map(cfgFor).getOrElse(Cfg.empty) - val falseCfg = node.whenFalse.headOption.map(cfgFor).getOrElse(Cfg.empty) - - val diffGraphs = edgesFromFringeTo(conditionCfg, trueCfg.entryNode) ++ - edgesFromFringeTo(conditionCfg, falseCfg.entryNode) - - Cfg - .from(conditionCfg, trueCfg, falseCfg) - .copy( - entryNode = conditionCfg.entryNode, - edges = diffGraphs ++ conditionCfg.edges ++ trueCfg.edges ++ falseCfg.edges, - fringe = trueCfg.fringe ++ { - if falseCfg.entryNode.isDefined then - falseCfg.fringe - else - conditionCfg.fringe.withEdgeType(FalseEdge) - } - ) - end cfgForIfStatement - - /** CFG creation for try statements of the form `try { tryBody ] catch { catchBody } `, - * optionally followed by `finally { finallyBody }`. - * - * To avoid very large CFGs for try statements, only edges from the last statement in the `try` - * block to each `catch` block (and optionally the `finally` block) are created. The last - * statement in each `catch` block should then have an outgoing edge to the `finally` block if - * it exists (and not to any subsequent catch blocks), or otherwise * be part of the fringe. - * - * By default, the first child of the `TRY` node is treated as the try body, while every - * subsequent node is treated as a `catch`, with no `finally` present. To treat the last child - * of the node as the `finally` block, the `code` field of the `Block` node must be set to - * `finally`. - */ - protected def cfgForTryStatement(node: ControlStructure): Cfg = - val maybeTryBlock = - node.astChildren - .where(_.order(1)) - .where(_.astChildren) // Filter out empty `try` bodies - .headOption - - val tryBodyCfg: Cfg = maybeTryBlock.map(cfgFor).getOrElse(Cfg.empty) - - val catchBodyCfgs: List[Cfg] = - node.astChildren - .where(_.order(2)) - .toList match - case Nil => List(Cfg.empty) - case asts => asts.map(cfgFor) - - val maybeFinallyBodyCfg: List[Cfg] = - node.astChildren - .where(_.order(3)) - .map(cfgFor) - .headOption // Assume there can only be one - .toList - - val tryToCatchEdges = catchBodyCfgs.flatMap { catchBodyCfg => - edgesFromFringeTo(tryBodyCfg, catchBodyCfg.entryNode) - } - - val catchToFinallyEdges = ( - for ( - catchBodyCfg <- catchBodyCfgs; - finallyBodyCfg <- maybeFinallyBodyCfg - ) yield edgesFromFringeTo(catchBodyCfg, finallyBodyCfg.entryNode) - ).flatten - - val tryToFinallyEdges = maybeFinallyBodyCfg.flatMap { cfg => - edgesFromFringeTo(tryBodyCfg, cfg.entryNode) + case None => + Cfg(entryNode = Option(node), continues = List((node, 1))) + + /** Jump targets ("labels") are included in the CFG. As these should be connected to the next + * appended CFG, we specify that the label node is both the entry node and the only node in the + * fringe. This is achieved by calling `cfgForSingleNode` on the label node. Just like for breaks + * and continues, we record labels. We store case/default labels separately from other labels, + * but that is not a relevant implementation detail. + */ + protected def cfgForJumpTarget(n: JumpTarget): Cfg = + val labelName = n.name + val cfg = cfgForSingleNode(n) + if labelName.startsWith("case") || labelName.startsWith("default") then + cfg.copy(caseLabels = List(n)) + else + cfg.copy(labeledNodes = Map(labelName -> n)) + + /** A CFG for a goto statement is one containing the goto node as an entry node and an empty + * fringe. Moreover, we store the goto for dispatching with `withResolvedJumpToLabel` once the + * CFG for the entire method has been calculated. + */ + protected def cfgForGotoStatement(node: ControlStructure): Cfg = + node.astChildren.find(_.order == 1) match + case Some(jumpLabel) => + val labelName = jumpLabel.asInstanceOf[JumpLabel].name + Cfg(entryNode = Option(node), jumpsToLabel = List((node, labelName))) + case None => + // Support for old format where the label name is parsed from the code field. + val target = node.code.split(" ").lastOption.map(x => x.slice(0, x.length - 1)) + target.map(t => + Cfg(entryNode = Some(node), jumpsToLabel = List((node, t))) + ).getOrElse(Cfg.empty) + + /** Return statements may contain expressions as return values, and therefore, the CFG for a + * return statement consists of the CFG for calculation of that expression, appended to a CFG + * containing only the return node, connected with a single edge to the method exit node. The + * fringe is empty. + * + * @param inheritFringe + * indicates if the resulting Cfg object must contain the fringe value of the return value's + * children. + */ + protected def cfgForReturn(actualRet: Return, inheritFringe: Boolean = false): Cfg = + val childrenCfg = cfgForChildren(actualRet) + childrenCfg ++ + Cfg( + entryNode = Option(actualRet), + edges = singleEdge(actualRet, exitNode), + if inheritFringe then childrenCfg.fringe else List() + ) + + /** The right hand side of a logical AND expression is only evaluated if the left hand side is + * true as the entire expression can only be true if both expressions are true. This is encoded + * in the corresponding control flow graph by creating control flow graphs for the left and right + * hand expressions and appending the two, where the fringe edge type of the left CFG is + * `TrueEdge`. + */ + protected def cfgForAndExpression(call: Call): Cfg = + val leftCfg = cfgFor(call.argument(1)) + val rightCfg = cfgFor(call.argument(2)) + val diffGraphs = edgesFromFringeTo( + leftCfg, + rightCfg.entryNode, + TrueEdge + ) ++ leftCfg.edges ++ rightCfg.edges + Cfg + .from(leftCfg, rightCfg) + .copy( + entryNode = leftCfg.entryNode, + edges = diffGraphs, + fringe = leftCfg.fringe ++ rightCfg.fringe + ) ++ cfgForSingleNode(call) + + /** Same construction recipe as for the AND expression, just that the fringe edge type of the left + * CFG is `FalseEdge`. + */ + protected def cfgForOrExpression(call: Call): Cfg = + val leftCfg = cfgFor(call.argument(1)) + val rightCfg = cfgFor(call.argument(2)) + val diffGraphs = edgesFromFringeTo( + leftCfg, + rightCfg.entryNode, + FalseEdge + ) ++ leftCfg.edges ++ rightCfg.edges + Cfg + .from(leftCfg, rightCfg) + .copy( + entryNode = leftCfg.entryNode, + edges = diffGraphs, + fringe = leftCfg.fringe ++ rightCfg.fringe + ) ++ cfgForSingleNode(call) + + /** A conditional expression is of the form `condition ? trueExpr ; falseExpr` where both + * `trueExpr` and `falseExpr` are optional. We create the corresponding CFGs by creating CFGs for + * the three expressions and adding edges between them. The new entry node is the condition entry + * node. + */ + protected def cfgForConditionalExpression(call: Call): Cfg = + val conditionCfg = cfgFor(call.argument(1)) + val trueCfg = call.argumentOption(2).map(cfgFor).getOrElse(Cfg.empty) + val falseCfg = call.argumentOption(3).map(cfgFor).getOrElse(Cfg.empty) + val diffGraphs = edgesFromFringeTo(conditionCfg, trueCfg.entryNode, TrueEdge) ++ + edgesFromFringeTo(conditionCfg, falseCfg.entryNode, FalseEdge) + + val trueFridge = if trueCfg.entryNode.isDefined then + trueCfg.fringe + else + conditionCfg.fringe.withEdgeType(TrueEdge) + val falseFridge = if falseCfg.entryNode.isDefined then + falseCfg.fringe + else + conditionCfg.fringe.withEdgeType(FalseEdge) + + Cfg + .from(conditionCfg, trueCfg, falseCfg) + .copy( + entryNode = conditionCfg.entryNode, + edges = conditionCfg.edges ++ trueCfg.edges ++ falseCfg.edges ++ diffGraphs, + fringe = trueFridge ++ falseFridge + ) ++ cfgForSingleNode(call) + end cfgForConditionalExpression + + /** For macros, the AST contains a CALL node, along with child sub trees for all arguments, and a + * final sub tree that contains the inlined code. The corresponding CFG consists of the CFG for + * the call, an edge to the exit and an edge to the CFG of the inlined code. We choose this + * representation because it allows both queries that use the macro reference as well as queries + * that reference the inline code to be chosen as sources/sinks in data flow queries. + */ + def cfgForInlinedCall(call: Call): Cfg = + val cfgForMacroCall = call.argument.l + .map(cfgFor) + .reduceOption((x, y) => x ++ y) + .getOrElse(Cfg.empty) ++ cfgForSingleNode(call) + val cfgForExpansion = call.astChildren.lastOption.map(cfgFor).getOrElse(Cfg.empty) + val cfg = Cfg + .from(cfgForMacroCall, cfgForExpansion) + .copy( + entryNode = cfgForMacroCall.entryNode, + edges = + cfgForMacroCall.edges ++ cfgForExpansion.edges ++ cfgForExpansion.entryNode.toList + .flatMap(x => singleEdge(call, x)), + fringe = cfgForMacroCall.fringe ++ cfgForExpansion.fringe + ) + cfg + + /** A for statement is of the form `for(initExpr; condition; loopExpr) body` and all four + * components may be empty. The sequence (condition - body - loopExpr) form the inner part of the + * loop and we calculate the corresponding CFG `innerCfg` so that it is no longer relevant which + * of these three actually exist and we still have an entry node for the loop and a fringe. + */ + protected def cfgForForStatement(node: ControlStructure): Cfg = + val children = node.astChildren.l + val nLocals = children.count(_.isLocal) + val initExprCfg = children.find(_.order == nLocals + 1).map(cfgFor).getOrElse(Cfg.empty) + val conditionCfg = children.find(_.order == nLocals + 2).map(cfgFor).getOrElse(Cfg.empty) + val loopExprCfg = children.find(_.order == nLocals + 3).map(cfgFor).getOrElse(Cfg.empty) + val bodyCfg = children.find(_.order == nLocals + 4).map(cfgFor).getOrElse(Cfg.empty) + + val innerCfg = conditionCfg ++ bodyCfg ++ loopExprCfg + val entryNode = (initExprCfg ++ innerCfg).entryNode + + val newEdges = edgesFromFringeTo(initExprCfg, innerCfg.entryNode) ++ + edgesFromFringeTo(innerCfg, innerCfg.entryNode) ++ + edgesFromFringeTo(conditionCfg, bodyCfg.entryNode, TrueEdge) ++ { + if loopExprCfg.entryNode.isDefined then + edges(takeCurrentLevel(bodyCfg.continues), loopExprCfg.entryNode) + else + edges(takeCurrentLevel(bodyCfg.continues), innerCfg.entryNode) } - val diffGraphs = tryToCatchEdges ++ catchToFinallyEdges ++ tryToFinallyEdges - - if maybeTryBlock.isEmpty then - // This case deals with the situation where the try block is empty. In this case, - // no catch block can be executed since nothing can be thrown, but the finally block - // will still be executed. - maybeFinallyBodyCfg.headOption.getOrElse(Cfg.empty) - else - Cfg - .from(Seq(tryBodyCfg) ++ catchBodyCfgs ++ maybeFinallyBodyCfg*) - .copy( - entryNode = tryBodyCfg.entryNode, - edges = - diffGraphs ++ tryBodyCfg.edges ++ catchBodyCfgs.flatMap( - _.edges - ) ++ maybeFinallyBodyCfg.flatMap(_.edges), - fringe = if maybeFinallyBodyCfg.flatMap(_.entryNode).nonEmpty then - maybeFinallyBodyCfg.head.fringe - else - tryBodyCfg.fringe ++ catchBodyCfgs.flatMap(_.fringe) - ) - end if - end cfgForTryStatement - - /** The CFGs for match cases are modeled after PHP match expressions and assumes that a case - * will always consist of one or more JumpTargets followed by a single expression. The CFG also - * assumes an implicit `break` at the end of each match case. - */ - protected def cfgsForMatchCases(body: AstNode): List[Cfg] = - body.astChildren - .foldLeft(List(Cfg.empty)) { - case (currCfg :: prevCfgs, astNode) => - astNode match - case jumpTarget: JumpTarget => - val jumpCfg = cfgForJumpTarget(jumpTarget) - (currCfg ++ jumpCfg) :: prevCfgs - - case node: AstNode => - val nodeCfg = cfgFor(node) - Cfg.empty :: (currCfg ++ nodeCfg) :: prevCfgs - case _ => List.empty - } - .reverse - - /** CFG creation for match expressions of the form `match { case condition: expr ... }` - */ - protected def cfgForMatchExpression(node: ControlStructure): Cfg = - val conditionCfg = node.condition.headOption.map(cfgFor).getOrElse(Cfg.empty) - val bodyCfgs = node.whenTrue.headOption.map(cfgsForMatchCases).getOrElse(Nil) - - cfgForSwitchLike(conditionCfg, bodyCfgs) - - protected def cfgForSwitchLike(conditionCfg: Cfg, bodyCfgs: List[Cfg]): Cfg = - val hasDefaultCase = - bodyCfgs.flatMap(_.caseLabels).exists(x => x.asInstanceOf[JumpTarget].name == "default") - val caseEdges = - edgesToMultiple(conditionCfg.fringe.map(_._1), bodyCfgs.flatMap(_.caseLabels), CaseEdge) - val breakFringe = takeCurrentLevel(bodyCfgs.flatMap(_.breaks)).map((_, AlwaysEdge)) - - Cfg - .from(conditionCfg :: bodyCfgs*) - .copy( - entryNode = conditionCfg.entryNode, - edges = caseEdges ++ conditionCfg.edges ++ bodyCfgs.flatMap(_.edges), - fringe = { - if !hasDefaultCase then conditionCfg.fringe.withEdgeType(FalseEdge) - else Nil - } ++ breakFringe ++ bodyCfgs.flatMap(_.fringe), - caseLabels = List(), - breaks = reduceAndFilterLevel(bodyCfgs.flatMap(_.breaks)), - continues = bodyCfgs.flatMap(_.continues) - ) - end cfgForSwitchLike + Cfg + .from(initExprCfg, conditionCfg, loopExprCfg, bodyCfg) + .copy( + entryNode = entryNode, + edges = newEdges ++ initExprCfg.edges ++ innerCfg.edges, + fringe = conditionCfg.fringe.withEdgeType(FalseEdge) ++ takeCurrentLevel( + bodyCfg.breaks + ).map((_, AlwaysEdge)), + breaks = reduceAndFilterLevel(bodyCfg.breaks), + continues = reduceAndFilterLevel(bodyCfg.continues) + ) + end cfgForForStatement + + /** A Do-Statement is of the form `do body while(condition)` where body may be empty. We again + * first calculate the inner CFG as bodyCfg ++ conditionCfg and then connect edges according to + * the semantics of do-while. + */ + protected def cfgForDoStatement(node: ControlStructure): Cfg = + val bodyCfg = node.astChildren.where(_.order(1)).headOption.map(cfgFor).getOrElse(Cfg.empty) + val conditionCfg = node.condition.headOption.map(cfgFor).getOrElse(Cfg.empty) + val innerCfg = bodyCfg ++ conditionCfg + + val diffGraphs = + edges(takeCurrentLevel(bodyCfg.continues), conditionCfg.entryNode) ++ + edgesFromFringeTo(bodyCfg, conditionCfg.entryNode) ++ + edgesFromFringeTo(conditionCfg, innerCfg.entryNode, TrueEdge) + + Cfg + .from(bodyCfg, conditionCfg, innerCfg) + .copy( + entryNode = if bodyCfg != Cfg.empty then bodyCfg.entryNode + else conditionCfg.entryNode, + edges = diffGraphs ++ bodyCfg.edges ++ conditionCfg.edges, + fringe = conditionCfg.fringe.withEdgeType(FalseEdge) ++ takeCurrentLevel( + bodyCfg.breaks + ).map((_, AlwaysEdge)), + breaks = reduceAndFilterLevel(bodyCfg.breaks), + continues = reduceAndFilterLevel(bodyCfg.continues) + ) + end cfgForDoStatement + + /** CFG creation for while statements of the form `while(condition) body1 else body2` where body1 + * and the else block are optional. + */ + protected def cfgForWhileStatement(node: ControlStructure): Cfg = + val conditionCfg = node.condition.headOption.map(cfgFor).getOrElse(Cfg.empty) + val trueCfg = node.whenTrue.headOption.map(cfgFor).getOrElse(Cfg.empty) + val falseCfg = node.whenFalse.headOption.map(cfgFor).getOrElse(Cfg.empty) + + val diffGraphs = edgesFromFringeTo(conditionCfg, trueCfg.entryNode) ++ + edgesFromFringeTo(trueCfg, falseCfg.entryNode) ++ + edgesFromFringeTo(trueCfg, conditionCfg.entryNode) ++ + edges(takeCurrentLevel(trueCfg.continues), conditionCfg.entryNode) + + Cfg + .from(conditionCfg, trueCfg, falseCfg) + .copy( + entryNode = conditionCfg.entryNode, + edges = diffGraphs ++ conditionCfg.edges ++ trueCfg.edges ++ falseCfg.edges, + fringe = + conditionCfg.fringe.withEdgeType(FalseEdge) ++ takeCurrentLevel(trueCfg.breaks) + .map((_, AlwaysEdge)) ++ falseCfg.fringe, + breaks = reduceAndFilterLevel(trueCfg.breaks), + continues = reduceAndFilterLevel(trueCfg.continues) + ) + end cfgForWhileStatement + + /** CFG creation for switch statements of the form `switch { case condition: ... }`. + */ + protected def cfgForSwitchStatement(node: ControlStructure): Cfg = + val conditionCfg = node.condition.headOption.map(cfgFor).getOrElse(Cfg.empty) + val bodyCfg = node.whenTrue.headOption.map(cfgFor).getOrElse(Cfg.empty) + + cfgForSwitchLike(conditionCfg, bodyCfg :: Nil) + + /** CFG creation for if statements of the form `if(condition) body`, optionally followed by `else + * body2`. + */ + protected def cfgForIfStatement(node: ControlStructure): Cfg = + val conditionCfg = node.condition.headOption.map(cfgFor).getOrElse(Cfg.empty) + val trueCfg = node.whenTrue.headOption.map(cfgFor).getOrElse(Cfg.empty) + val falseCfg = node.whenFalse.headOption.map(cfgFor).getOrElse(Cfg.empty) + + val diffGraphs = edgesFromFringeTo(conditionCfg, trueCfg.entryNode) ++ + edgesFromFringeTo(conditionCfg, falseCfg.entryNode) + + Cfg + .from(conditionCfg, trueCfg, falseCfg) + .copy( + entryNode = conditionCfg.entryNode, + edges = diffGraphs ++ conditionCfg.edges ++ trueCfg.edges ++ falseCfg.edges, + fringe = trueCfg.fringe ++ { + if falseCfg.entryNode.isDefined then + falseCfg.fringe + else + conditionCfg.fringe.withEdgeType(FalseEdge) + } + ) + end cfgForIfStatement + + /** CFG creation for try statements of the form `try { tryBody ] catch { catchBody } `, optionally + * followed by `finally { finallyBody }`. + * + * To avoid very large CFGs for try statements, only edges from the last statement in the `try` + * block to each `catch` block (and optionally the `finally` block) are created. The last + * statement in each `catch` block should then have an outgoing edge to the `finally` block if it + * exists (and not to any subsequent catch blocks), or otherwise * be part of the fringe. + * + * By default, the first child of the `TRY` node is treated as the try body, while every + * subsequent node is treated as a `catch`, with no `finally` present. To treat the last child of + * the node as the `finally` block, the `code` field of the `Block` node must be set to + * `finally`. + */ + protected def cfgForTryStatement(node: ControlStructure): Cfg = + val maybeTryBlock = + node.astChildren + .where(_.order(1)) + .where(_.astChildren) // Filter out empty `try` bodies + .headOption + + val tryBodyCfg: Cfg = maybeTryBlock.map(cfgFor).getOrElse(Cfg.empty) + + val catchBodyCfgs: List[Cfg] = + node.astChildren + .where(_.order(2)) + .toList match + case Nil => List(Cfg.empty) + case asts => asts.map(cfgFor) + + val maybeFinallyBodyCfg: List[Cfg] = + node.astChildren + .where(_.order(3)) + .map(cfgFor) + .headOption // Assume there can only be one + .toList + + val tryToCatchEdges = catchBodyCfgs.flatMap { catchBodyCfg => + edgesFromFringeTo(tryBodyCfg, catchBodyCfg.entryNode) + } + + val catchToFinallyEdges = ( + for ( + catchBodyCfg <- catchBodyCfgs; + finallyBodyCfg <- maybeFinallyBodyCfg + ) yield edgesFromFringeTo(catchBodyCfg, finallyBodyCfg.entryNode) + ).flatten + + val tryToFinallyEdges = maybeFinallyBodyCfg.flatMap { cfg => + edgesFromFringeTo(tryBodyCfg, cfg.entryNode) + } + + val diffGraphs = tryToCatchEdges ++ catchToFinallyEdges ++ tryToFinallyEdges + + if maybeTryBlock.isEmpty then + // This case deals with the situation where the try block is empty. In this case, + // no catch block can be executed since nothing can be thrown, but the finally block + // will still be executed. + maybeFinallyBodyCfg.headOption.getOrElse(Cfg.empty) + else + Cfg + .from(Seq(tryBodyCfg) ++ catchBodyCfgs ++ maybeFinallyBodyCfg*) + .copy( + entryNode = tryBodyCfg.entryNode, + edges = + diffGraphs ++ tryBodyCfg.edges ++ catchBodyCfgs.flatMap( + _.edges + ) ++ maybeFinallyBodyCfg.flatMap(_.edges), + fringe = if maybeFinallyBodyCfg.flatMap(_.entryNode).nonEmpty then + maybeFinallyBodyCfg.head.fringe + else + tryBodyCfg.fringe ++ catchBodyCfgs.flatMap(_.fringe) + ) + end if + end cfgForTryStatement + + /** The CFGs for match cases are modeled after PHP match expressions and assumes that a case will + * always consist of one or more JumpTargets followed by a single expression. The CFG also + * assumes an implicit `break` at the end of each match case. + */ + protected def cfgsForMatchCases(body: AstNode): List[Cfg] = + body.astChildren + .foldLeft(List(Cfg.empty)) { + case (currCfg :: prevCfgs, astNode) => + astNode match + case jumpTarget: JumpTarget => + val jumpCfg = cfgForJumpTarget(jumpTarget) + (currCfg ++ jumpCfg) :: prevCfgs + + case node: AstNode => + val nodeCfg = cfgFor(node) + Cfg.empty :: (currCfg ++ nodeCfg) :: prevCfgs + case _ => List.empty + } + .reverse + + /** CFG creation for match expressions of the form `match { case condition: expr ... }` + */ + protected def cfgForMatchExpression(node: ControlStructure): Cfg = + val conditionCfg = node.condition.headOption.map(cfgFor).getOrElse(Cfg.empty) + val bodyCfgs = node.whenTrue.headOption.map(cfgsForMatchCases).getOrElse(Nil) + + cfgForSwitchLike(conditionCfg, bodyCfgs) + + protected def cfgForSwitchLike(conditionCfg: Cfg, bodyCfgs: List[Cfg]): Cfg = + val hasDefaultCase = + bodyCfgs.flatMap(_.caseLabels).exists(x => x.asInstanceOf[JumpTarget].name == "default") + val caseEdges = + edgesToMultiple(conditionCfg.fringe.map(_._1), bodyCfgs.flatMap(_.caseLabels), CaseEdge) + val breakFringe = takeCurrentLevel(bodyCfgs.flatMap(_.breaks)).map((_, AlwaysEdge)) + + Cfg + .from(conditionCfg :: bodyCfgs*) + .copy( + entryNode = conditionCfg.entryNode, + edges = caseEdges ++ conditionCfg.edges ++ bodyCfgs.flatMap(_.edges), + fringe = { + if !hasDefaultCase then conditionCfg.fringe.withEdgeType(FalseEdge) + else Nil + } ++ breakFringe ++ bodyCfgs.flatMap(_.fringe), + caseLabels = List(), + breaks = reduceAndFilterLevel(bodyCfgs.flatMap(_.breaks)), + continues = bodyCfgs.flatMap(_.continues) + ) + end cfgForSwitchLike end CfgCreator object CfgCreator: - implicit class FringeWrapper(fringe: List[(CfgNode, CfgEdgeType)]): - def withEdgeType(edgeType: CfgEdgeType): List[(CfgNode, CfgEdgeType)] = - fringe.map { case (x, _) => (x, edgeType) } + implicit class FringeWrapper(fringe: List[(CfgNode, CfgEdgeType)]): + def withEdgeType(edgeType: CfgEdgeType): List[(CfgNode, CfgEdgeType)] = + fringe.map { case (x, _) => (x, edgeType) } diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CfgAdapter.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CfgAdapter.scala index 59bcb21e..7c791799 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CfgAdapter.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CfgAdapter.scala @@ -1,5 +1,5 @@ package io.appthreat.x2cpg.passes.controlflow.cfgdominator trait CfgAdapter[Node]: - def successors(node: Node): IterableOnce[Node] - def predecessors(node: Node): IterableOnce[Node] + def successors(node: Node): IterableOnce[Node] + def predecessors(node: Node): IterableOnce[Node] diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CfgDominator.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CfgDominator.scala index 5b446ef0..bd4d5068 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CfgDominator.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CfgDominator.scala @@ -5,83 +5,83 @@ import scala.collection.mutable class CfgDominator[NodeType](adapter: CfgAdapter[NodeType]): - /** Calculates the immediate dominators of all CFG nodes reachable from cfgEntry. Since the - * cfgEntry does not have an immediate dominator, it has no entry in the return map. - * - * The algorithm is from: "A Simple, Fast Dominance Algorithm" from "Keith D. Cooper, Timothy - * J. Harvey, and Ken Kennedy". - */ - def calculate(cfgEntry: NodeType): mutable.LinkedHashMap[NodeType, NodeType] = - val UNDEFINED = -1 - def expand(x: NodeType) = adapter.successors(x).iterator - def expandBack(x: NodeType) = adapter.predecessors(x).iterator + /** Calculates the immediate dominators of all CFG nodes reachable from cfgEntry. Since the + * cfgEntry does not have an immediate dominator, it has no entry in the return map. + * + * The algorithm is from: "A Simple, Fast Dominance Algorithm" from "Keith D. Cooper, Timothy J. + * Harvey, and Ken Kennedy". + */ + def calculate(cfgEntry: NodeType): mutable.LinkedHashMap[NodeType, NodeType] = + val UNDEFINED = -1 + def expand(x: NodeType) = adapter.successors(x).iterator + def expandBack(x: NodeType) = adapter.predecessors(x).iterator - val postOrderNumbering = NodeOrdering.postOrderNumbering(cfgEntry, expand) - val nodesInReversePostOrder = - NodeOrdering.reverseNodeList(postOrderNumbering.toList).filterNot(_ == cfgEntry) - // Index of each node into dominators array. - val indexOf = postOrderNumbering.withDefaultValue(UNDEFINED) - // We use withDefault because unreachable/dead - // code nodes are not numbered but may be touched - // as predecessors of reachable nodes. + val postOrderNumbering = NodeOrdering.postOrderNumbering(cfgEntry, expand) + val nodesInReversePostOrder = + NodeOrdering.reverseNodeList(postOrderNumbering.toList).filterNot(_ == cfgEntry) + // Index of each node into dominators array. + val indexOf = postOrderNumbering.withDefaultValue(UNDEFINED) + // We use withDefault because unreachable/dead + // code nodes are not numbered but may be touched + // as predecessors of reachable nodes. - val dominators = Array.fill(indexOf.size)(UNDEFINED) - dominators(indexOf(cfgEntry)) = indexOf(cfgEntry) + val dominators = Array.fill(indexOf.size)(UNDEFINED) + dominators(indexOf(cfgEntry)) = indexOf(cfgEntry) - /* Retrieve index of immediate dominator for node with given index. If the index is `UNDEFINED`, UNDEFINED is - * returned. */ - def safeDominators(index: Int): Int = - if index != UNDEFINED then - dominators(index) - else - UNDEFINED + /* Retrieve index of immediate dominator for node with given index. If the index is `UNDEFINED`, UNDEFINED is + * returned. */ + def safeDominators(index: Int): Int = + if index != UNDEFINED then + dominators(index) + else + UNDEFINED - var changed = true - while changed do - changed = false - nodesInReversePostOrder.foreach { node => - val firstNotUndefinedPred = expandBack(node).find { predecessor => - safeDominators(indexOf(predecessor)) != UNDEFINED - }.get + var changed = true + while changed do + changed = false + nodesInReversePostOrder.foreach { node => + val firstNotUndefinedPred = expandBack(node).find { predecessor => + safeDominators(indexOf(predecessor)) != UNDEFINED + }.get - var newImmediateDominator = indexOf(firstNotUndefinedPred) - expandBack(node).foreach { predecessor => - val predecessorIndex = indexOf(predecessor) - if safeDominators(predecessorIndex) != UNDEFINED then - newImmediateDominator = - intersect(dominators, predecessorIndex, newImmediateDominator) - } + var newImmediateDominator = indexOf(firstNotUndefinedPred) + expandBack(node).foreach { predecessor => + val predecessorIndex = indexOf(predecessor) + if safeDominators(predecessorIndex) != UNDEFINED then + newImmediateDominator = + intersect(dominators, predecessorIndex, newImmediateDominator) + } - val nodeIndex = indexOf(node) - if dominators(nodeIndex) != newImmediateDominator then - dominators(nodeIndex) = newImmediateDominator - changed = true - } - end while + val nodeIndex = indexOf(node) + if dominators(nodeIndex) != newImmediateDominator then + dominators(nodeIndex) = newImmediateDominator + changed = true + } + end while - val postOrderNumberingToNode = postOrderNumbering.map { case (node, index) => - (index, node) - } + val postOrderNumberingToNode = postOrderNumbering.map { case (node, index) => + (index, node) + } - postOrderNumbering.collect { - case (node, index) if node != cfgEntry => - val immediateDominatorIndex = dominators(index) - (node, postOrderNumberingToNode(immediateDominatorIndex)) - } - end calculate + postOrderNumbering.collect { + case (node, index) if node != cfgEntry => + val immediateDominatorIndex = dominators(index) + (node, postOrderNumberingToNode(immediateDominatorIndex)) + } + end calculate - private def intersect( - dominators: Array[Int], - immediateDomIndex1: Int, - immediateDomIndex2: Int - ): Int = - var finger1 = immediateDomIndex1 - var finger2 = immediateDomIndex2 + private def intersect( + dominators: Array[Int], + immediateDomIndex1: Int, + immediateDomIndex2: Int + ): Int = + var finger1 = immediateDomIndex1 + var finger2 = immediateDomIndex2 - while finger1 != finger2 do - while finger1 < finger2 do - finger1 = dominators(finger1) - while finger2 < finger1 do - finger2 = dominators(finger2) - finger1 + while finger1 != finger2 do + while finger1 < finger2 do + finger1 = dominators(finger1) + while finger2 < finger1 do + finger2 = dominators(finger2) + finger1 end CfgDominator diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CfgDominatorFrontier.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CfgDominatorFrontier.scala index 434d0b93..90b7218b 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CfgDominatorFrontier.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CfgDominatorFrontier.scala @@ -13,29 +13,29 @@ class CfgDominatorFrontier[NodeType]( domTreeAdapter: DomTreeAdapter[NodeType] ): - private def doms(x: NodeType): Option[NodeType] = domTreeAdapter.immediateDominator(x) - private def pred(x: NodeType): Seq[NodeType] = cfgAdapter.predecessors(x).iterator.to(Seq) - private def onlyJoinNodes(x: NodeType): Option[(NodeType, Seq[NodeType])] = - Option(pred(x)).filter(_.size > 1).map(p => (x, p)) - private def withIDom(x: NodeType, preds: Seq[NodeType]) = - doms(x).map(i => (x, preds, i)) + private def doms(x: NodeType): Option[NodeType] = domTreeAdapter.immediateDominator(x) + private def pred(x: NodeType): Seq[NodeType] = cfgAdapter.predecessors(x).iterator.to(Seq) + private def onlyJoinNodes(x: NodeType): Option[(NodeType, Seq[NodeType])] = + Option(pred(x)).filter(_.size > 1).map(p => (x, p)) + private def withIDom(x: NodeType, preds: Seq[NodeType]) = + doms(x).map(i => (x, preds, i)) - def calculate(cfgNodes: Seq[NodeType]): mutable.Map[NodeType, mutable.Set[NodeType]] = - val domFrontier = mutable.Map.empty[NodeType, mutable.Set[NodeType]] + def calculate(cfgNodes: Seq[NodeType]): mutable.Map[NodeType, mutable.Set[NodeType]] = + val domFrontier = mutable.Map.empty[NodeType, mutable.Set[NodeType]] - for - cfgNode <- cfgNodes - (nodeType, joinNodes) <- onlyJoinNodes(cfgNode) - (joinNode, preds, joinNodeIDom) <- withIDom(nodeType, joinNodes) - do - preds.foreach { p => - var currentPred = Option(p) - while currentPred.isDefined && currentPred.get != joinNodeIDom do - val frontierNodes = - domFrontier.getOrElseUpdate(currentPred.get, mutable.Set.empty) - frontierNodes.add(joinNode) - currentPred = doms(currentPred.get) - } + for + cfgNode <- cfgNodes + (nodeType, joinNodes) <- onlyJoinNodes(cfgNode) + (joinNode, preds, joinNodeIDom) <- withIDom(nodeType, joinNodes) + do + preds.foreach { p => + var currentPred = Option(p) + while currentPred.isDefined && currentPred.get != joinNodeIDom do + val frontierNodes = + domFrontier.getOrElseUpdate(currentPred.get, mutable.Set.empty) + frontierNodes.add(joinNode) + currentPred = doms(currentPred.get) + } - domFrontier + domFrontier end CfgDominatorFrontier diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CfgDominatorPass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CfgDominatorPass.scala index e424b989..be05994e 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CfgDominatorPass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CfgDominatorPass.scala @@ -11,34 +11,34 @@ import scala.collection.mutable /** This pass has no prerequisites. */ class CfgDominatorPass(cpg: Cpg) extends ForkJoinParallelCpgPass[Method](cpg): - override def generateParts(): Array[Method] = cpg.method.toArray - - override def runOnPart(dstGraph: DiffGraphBuilder, method: Method): Unit = - val cfgAdapter = new CpgCfgAdapter() - val dominatorCalculator = new CfgDominator(cfgAdapter) - - val reverseCfgAdapter = new ReverseCpgCfgAdapter() - val postDominatorCalculator = new CfgDominator(reverseCfgAdapter) - - val cfgNodeToImmediateDominator = dominatorCalculator.calculate(method) - addDomTreeEdges(dstGraph, cfgNodeToImmediateDominator) - - val cfgNodeToPostImmediateDominator = postDominatorCalculator.calculate(method.methodReturn) - addPostDomTreeEdges(dstGraph, cfgNodeToPostImmediateDominator) - - private def addDomTreeEdges( - dstGraph: DiffGraphBuilder, - cfgNodeToImmediateDominator: mutable.LinkedHashMap[StoredNode, StoredNode] - ): Unit = - cfgNodeToImmediateDominator.foreach { case (node, immediateDominator) => - dstGraph.addEdge(immediateDominator, node, EdgeTypes.DOMINATE) - } - - private def addPostDomTreeEdges( - dstGraph: DiffGraphBuilder, - cfgNodeToPostImmediateDominator: mutable.LinkedHashMap[StoredNode, StoredNode] - ): Unit = - cfgNodeToPostImmediateDominator.foreach { case (node, immediatePostDominator) => - dstGraph.addEdge(immediatePostDominator, node, EdgeTypes.POST_DOMINATE) - } + override def generateParts(): Array[Method] = cpg.method.toArray + + override def runOnPart(dstGraph: DiffGraphBuilder, method: Method): Unit = + val cfgAdapter = new CpgCfgAdapter() + val dominatorCalculator = new CfgDominator(cfgAdapter) + + val reverseCfgAdapter = new ReverseCpgCfgAdapter() + val postDominatorCalculator = new CfgDominator(reverseCfgAdapter) + + val cfgNodeToImmediateDominator = dominatorCalculator.calculate(method) + addDomTreeEdges(dstGraph, cfgNodeToImmediateDominator) + + val cfgNodeToPostImmediateDominator = postDominatorCalculator.calculate(method.methodReturn) + addPostDomTreeEdges(dstGraph, cfgNodeToPostImmediateDominator) + + private def addDomTreeEdges( + dstGraph: DiffGraphBuilder, + cfgNodeToImmediateDominator: mutable.LinkedHashMap[StoredNode, StoredNode] + ): Unit = + cfgNodeToImmediateDominator.foreach { case (node, immediateDominator) => + dstGraph.addEdge(immediateDominator, node, EdgeTypes.DOMINATE) + } + + private def addPostDomTreeEdges( + dstGraph: DiffGraphBuilder, + cfgNodeToPostImmediateDominator: mutable.LinkedHashMap[StoredNode, StoredNode] + ): Unit = + cfgNodeToPostImmediateDominator.foreach { case (node, immediatePostDominator) => + dstGraph.addEdge(immediatePostDominator, node, EdgeTypes.POST_DOMINATE) + } end CfgDominatorPass diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CpgCfgAdapter.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CpgCfgAdapter.scala index f880b578..ee2ecce1 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CpgCfgAdapter.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/CpgCfgAdapter.scala @@ -4,8 +4,8 @@ import io.shiftleft.codepropertygraph.generated.nodes.StoredNode class CpgCfgAdapter extends CfgAdapter[StoredNode]: - override def successors(node: StoredNode): IterableOnce[StoredNode] = - node._cfgOut + override def successors(node: StoredNode): IterableOnce[StoredNode] = + node._cfgOut - override def predecessors(node: StoredNode): IterableOnce[StoredNode] = - node._cfgIn + override def predecessors(node: StoredNode): IterableOnce[StoredNode] = + node._cfgIn diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/DomTreeAdapter.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/DomTreeAdapter.scala index f50385b0..323db184 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/DomTreeAdapter.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/DomTreeAdapter.scala @@ -2,9 +2,9 @@ package io.appthreat.x2cpg.passes.controlflow.cfgdominator trait DomTreeAdapter[Node]: - /** Returns the immediate dominator of a cfgNode. The returned value can be None if cfgNode was - * the cfg entry node while calculating the dominator relation or if cfgNode is dead code. In - * the post dominator case "dead code" means code which does lead to the normal method exit. An - * example would be a thrown excpetion. - */ - def immediateDominator(cfgNode: Node): Option[Node] + /** Returns the immediate dominator of a cfgNode. The returned value can be None if cfgNode was + * the cfg entry node while calculating the dominator relation or if cfgNode is dead code. In the + * post dominator case "dead code" means code which does lead to the normal method exit. An + * example would be a thrown excpetion. + */ + def immediateDominator(cfgNode: Node): Option[Node] diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/ReverseCpgCfgAdapter.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/ReverseCpgCfgAdapter.scala index 7505aed4..66270ad9 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/ReverseCpgCfgAdapter.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/cfgdominator/ReverseCpgCfgAdapter.scala @@ -4,8 +4,8 @@ import io.shiftleft.codepropertygraph.generated.nodes.StoredNode class ReverseCpgCfgAdapter extends CfgAdapter[StoredNode]: - override def successors(node: StoredNode): IterableOnce[StoredNode] = - node._cfgIn + override def successors(node: StoredNode): IterableOnce[StoredNode] = + node._cfgIn - override def predecessors(node: StoredNode): IterableOnce[StoredNode] = - node._cfgOut + override def predecessors(node: StoredNode): IterableOnce[StoredNode] = + node._cfgOut diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/codepencegraph/CdgPass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/codepencegraph/CdgPass.scala index 59a28f0e..89b98567 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/codepencegraph/CdgPass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/codepencegraph/CdgPass.scala @@ -23,44 +23,44 @@ import org.slf4j.{Logger, LoggerFactory} /** This pass has ContainsEdgePass and CfgDominatorPass as prerequisites. */ class CdgPass(cpg: Cpg) extends ForkJoinParallelCpgPass[Method](cpg): - import CdgPass.logger + import CdgPass.logger - override def generateParts(): Array[Method] = cpg.method.toArray + override def generateParts(): Array[Method] = cpg.method.toArray - override def runOnPart(dstGraph: DiffGraphBuilder, method: Method): Unit = + override def runOnPart(dstGraph: DiffGraphBuilder, method: Method): Unit = - val dominanceFrontier = - new CfgDominatorFrontier(new ReverseCpgCfgAdapter, new CpgPostDomTreeAdapter) + val dominanceFrontier = + new CfgDominatorFrontier(new ReverseCpgCfgAdapter, new CpgPostDomTreeAdapter) - val cfgNodes = method._containsOut.toList - val postDomFrontiers = dominanceFrontier.calculate(method :: cfgNodes) + val cfgNodes = method._containsOut.toList + val postDomFrontiers = dominanceFrontier.calculate(method :: cfgNodes) - postDomFrontiers.foreach { case (node, postDomFrontierNodes) => - postDomFrontierNodes.foreach { - case postDomFrontierNode @ (_: Literal | _: Identifier | _: Call | _: MethodRef | _: Unknown | - _: ControlStructure | _: JumpTarget) => - dstGraph.addEdge(postDomFrontierNode, node, EdgeTypes.CDG) - case postDomFrontierNode => - val nodeLabel = postDomFrontierNode.label - val containsIn = postDomFrontierNode._containsIn - if containsIn == null || !containsIn.hasNext then - logger.debug( - s"Found CDG edge starting at $nodeLabel node. This is most likely caused by an invalid CFG." - ) - else - val method = containsIn.next() - logger.debug( - s"Found CDG edge starting at $nodeLabel node. This is most likely caused by an invalid CFG." + - s" Method: ${method match - case m: Method => m.fullName; - case other => other.label - }" + - s" number of outgoing CFG edges from $nodeLabel node: ${postDomFrontierNode._cfgOut.size}" - ) - } + postDomFrontiers.foreach { case (node, postDomFrontierNodes) => + postDomFrontierNodes.foreach { + case postDomFrontierNode @ (_: Literal | _: Identifier | _: Call | _: MethodRef | _: Unknown | + _: ControlStructure | _: JumpTarget) => + dstGraph.addEdge(postDomFrontierNode, node, EdgeTypes.CDG) + case postDomFrontierNode => + val nodeLabel = postDomFrontierNode.label + val containsIn = postDomFrontierNode._containsIn + if containsIn == null || !containsIn.hasNext then + logger.debug( + s"Found CDG edge starting at $nodeLabel node. This is most likely caused by an invalid CFG." + ) + else + val method = containsIn.next() + logger.debug( + s"Found CDG edge starting at $nodeLabel node. This is most likely caused by an invalid CFG." + + s" Method: ${method match + case m: Method => m.fullName; + case other => other.label + }" + + s" number of outgoing CFG edges from $nodeLabel node: ${postDomFrontierNode._cfgOut.size}" + ) } - end runOnPart + } + end runOnPart end CdgPass object CdgPass: - private val logger: Logger = LoggerFactory.getLogger(classOf[CdgPass]) + private val logger: Logger = LoggerFactory.getLogger(classOf[CdgPass]) diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/codepencegraph/CpgPostDomTreeAdapter.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/codepencegraph/CpgPostDomTreeAdapter.scala index 873a40ef..3b2f470b 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/codepencegraph/CpgPostDomTreeAdapter.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/controlflow/codepencegraph/CpgPostDomTreeAdapter.scala @@ -5,5 +5,5 @@ import io.appthreat.x2cpg.passes.controlflow.cfgdominator.DomTreeAdapter class CpgPostDomTreeAdapter extends DomTreeAdapter[StoredNode]: - override def immediateDominator(cfgNode: StoredNode): Option[StoredNode] = - cfgNode._postDominateIn.nextOption() + override def immediateDominator(cfgNode: StoredNode): Option[StoredNode] = + cfgNode._postDominateIn.nextOption() diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/Dereference.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/Dereference.scala index 0efed058..2076335a 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/Dereference.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/Dereference.scala @@ -6,22 +6,22 @@ import io.shiftleft.semanticcpg.language.* object Dereference: - def apply(cpg: Cpg): Dereference = cpg.metaData.language.headOption match - case Some(Languages.NEWC) => CDereference() - case _ => DefaultDereference() + def apply(cpg: Cpg): Dereference = cpg.metaData.language.headOption match + case Some(Languages.NEWC) => CDereference() + case _ => DefaultDereference() sealed trait Dereference: - def dereferenceTypeFullName(fullName: String): String + def dereferenceTypeFullName(fullName: String): String case class CDereference() extends Dereference: - /** Types from C/C++ can be annotated with * to indicate being a reference. As our CPG schema - * currently lacks a separate field for that information the * is part of the type full name - * and needs to be removed when linking. - */ - override def dereferenceTypeFullName(fullName: String): String = fullName.replace("*", "") + /** Types from C/C++ can be annotated with * to indicate being a reference. As our CPG schema + * currently lacks a separate field for that information the * is part of the type full name and + * needs to be removed when linking. + */ + override def dereferenceTypeFullName(fullName: String): String = fullName.replace("*", "") case class DefaultDereference() extends Dereference: - override def dereferenceTypeFullName(fullName: String): String = fullName + override def dereferenceTypeFullName(fullName: String): String = fullName diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/MetaDataPass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/MetaDataPass.scala index 7ba7a831..48a03b1a 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/MetaDataPass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/MetaDataPass.scala @@ -10,28 +10,28 @@ import io.shiftleft.semanticcpg.language.types.structure.{FileTraversal, Namespa * NamespaceBlock for anything that cannot be assigned to any other namespace. */ class MetaDataPass(cpg: Cpg, language: String, root: String) extends CpgPass(cpg): - override def run(diffGraph: DiffGraphBuilder): Unit = - def addMetaDataNode(diffGraph: DiffGraphBuilder): Unit = - val absolutePathToRoot = File(root).path.toAbsolutePath.toString - val metaNode = NewMetaData().language(language).root(absolutePathToRoot).version("0.1") - diffGraph.addNode(metaNode) + override def run(diffGraph: DiffGraphBuilder): Unit = + def addMetaDataNode(diffGraph: DiffGraphBuilder): Unit = + val absolutePathToRoot = File(root).path.toAbsolutePath.toString + val metaNode = NewMetaData().language(language).root(absolutePathToRoot).version("0.1") + diffGraph.addNode(metaNode) - def addAnyNamespaceBlock(diffGraph: DiffGraphBuilder): Unit = - val node = NewNamespaceBlock() - .name(NamespaceTraversal.globalNamespaceName) - .fullName(MetaDataPass.getGlobalNamespaceBlockFullName(None)) - .filename(FileTraversal.UNKNOWN) - .order(1) - diffGraph.addNode(node) + def addAnyNamespaceBlock(diffGraph: DiffGraphBuilder): Unit = + val node = NewNamespaceBlock() + .name(NamespaceTraversal.globalNamespaceName) + .fullName(MetaDataPass.getGlobalNamespaceBlockFullName(None)) + .filename(FileTraversal.UNKNOWN) + .order(1) + diffGraph.addNode(node) - addMetaDataNode(diffGraph) - addAnyNamespaceBlock(diffGraph) + addMetaDataNode(diffGraph) + addAnyNamespaceBlock(diffGraph) object MetaDataPass: - def getGlobalNamespaceBlockFullName(fileNameOption: Option[String]): String = - fileNameOption match - case Some(fileName) => - s"$fileName:${NamespaceTraversal.globalNamespaceName}" - case None => - NamespaceTraversal.globalNamespaceName + def getGlobalNamespaceBlockFullName(fileNameOption: Option[String]): String = + fileNameOption match + case Some(fileName) => + s"$fileName:${NamespaceTraversal.globalNamespaceName}" + case None => + NamespaceTraversal.globalNamespaceName diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/SymbolTable.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/SymbolTable.scala index a279ec5f..ffe292c2 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/SymbolTable.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/SymbolTable.scala @@ -11,41 +11,41 @@ import scala.collection.concurrent.TrieMap */ abstract class SBKey(val identifier: String): - /** Convenience methods to convert a node to a [[SBKey]]. - * - * @param node - * the node to convert. - * @return - * the corresponding [[SBKey]] if the node is supported as a key variable - */ - def fromNode(node: AstNode): Option[SBKey] + /** Convenience methods to convert a node to a [[SBKey]]. + * + * @param node + * the node to convert. + * @return + * the corresponding [[SBKey]] if the node is supported as a key variable + */ + def fromNode(node: AstNode): Option[SBKey] object SBKey: - protected val logger: Logger = LoggerFactory.getLogger(getClass) - def fromNodeToLocalKey(node: AstNode): Option[LocalKey] = - Option(node match - case n: Identifier => LocalVar(n.name) - case n: Local => LocalVar(n.name) - case n: Call => - CallAlias( - n.name, - n.argument.collectFirst { case x: Identifier if x.argumentIndex == 0 => x.name } - ) - case n: Method => CallAlias(n.name, Option("this")) - case n: MethodRef => CallAlias(n.code) - case n: FieldIdentifier => LocalVar(n.canonicalName) - case n: MethodParameterIn => LocalVar(n.name) - case _ => - logger.debug( - s"Local node of type ${node.label} is not supported in the type recovery pass." - ); null - ) + protected val logger: Logger = LoggerFactory.getLogger(getClass) + def fromNodeToLocalKey(node: AstNode): Option[LocalKey] = + Option(node match + case n: Identifier => LocalVar(n.name) + case n: Local => LocalVar(n.name) + case n: Call => + CallAlias( + n.name, + n.argument.collectFirst { case x: Identifier if x.argumentIndex == 0 => x.name } + ) + case n: Method => CallAlias(n.name, Option("this")) + case n: MethodRef => CallAlias(n.code) + case n: FieldIdentifier => LocalVar(n.canonicalName) + case n: MethodParameterIn => LocalVar(n.name) + case _ => + logger.debug( + s"Local node of type ${node.label} is not supported in the type recovery pass." + ); null + ) end SBKey /** Represents an identifier of some AST node at an intraprocedural scope. */ sealed class LocalKey(identifier: String) extends SBKey(identifier): - override def fromNode(node: AstNode): Option[SBKey] = SBKey.fromNodeToLocalKey(node) + override def fromNode(node: AstNode): Option[SBKey] = SBKey.fromNodeToLocalKey(node) /** A variable that holds data within an intraprocedural scope. */ @@ -70,80 +70,80 @@ case class CallAlias(override val identifier: String, receiverName: Option[Strin */ class SymbolTable[K <: SBKey](val keyFromNode: AstNode => Option[K]): - private val table = TrieMap.empty[K, Set[String]] + private val table = TrieMap.empty[K, Set[String]] - /** The set limit is to bound the set of possible types, since by using dummy types we could - * have an unbounded number of permutations of various access paths - */ - private val setLimit = 10 + /** The set limit is to bound the set of possible types, since by using dummy types we could have + * an unbounded number of permutations of various access paths + */ + private val setLimit = 10 - private def coalesce(oldEntries: Set[String], newEntries: Set[String]): Set[String] = - val allTypes = - (oldEntries ++ newEntries).toSeq // convert to ordered set to make `take` work predictably - val (dummies, noDummies) = allTypes.partition(XTypeRecovery.isDummyType) - (noDummies ++ dummies).take(setLimit).toSet + private def coalesce(oldEntries: Set[String], newEntries: Set[String]): Set[String] = + val allTypes = + (oldEntries ++ newEntries).toSeq // convert to ordered set to make `take` work predictably + val (dummies, noDummies) = allTypes.partition(XTypeRecovery.isDummyType) + (noDummies ++ dummies).take(setLimit).toSet - def apply(sbKey: K): Set[String] = table(sbKey) + def apply(sbKey: K): Set[String] = table(sbKey) - def apply(node: AstNode): Set[String] = - keyFromNode(node) match - case Some(key) => table(key) - case None => Set.empty + def apply(node: AstNode): Set[String] = + keyFromNode(node) match + case Some(key) => table(key) + case None => Set.empty - def from(sb: IterableOnce[(K, Set[String])]): SymbolTable[K] = - table.addAll(sb); this + def from(sb: IterableOnce[(K, Set[String])]): SymbolTable[K] = + table.addAll(sb); this - def put(sbKey: K, typeFullNames: Set[String]): Set[String] = - if typeFullNames.nonEmpty then - val newEntry = coalesce(Set.empty, typeFullNames) - table.put(sbKey, newEntry) - newEntry - else - Set.empty + def put(sbKey: K, typeFullNames: Set[String]): Set[String] = + if typeFullNames.nonEmpty then + val newEntry = coalesce(Set.empty, typeFullNames) + table.put(sbKey, newEntry) + newEntry + else + Set.empty - def put(sbKey: K, typeFullName: String): Set[String] = - put(sbKey, Set(typeFullName)) + def put(sbKey: K, typeFullName: String): Set[String] = + put(sbKey, Set(typeFullName)) - def put(node: AstNode, typeFullNames: Set[String]): Set[String] = keyFromNode(node) match - case Some(key) => put(key, typeFullNames) - case None => Set.empty + def put(node: AstNode, typeFullNames: Set[String]): Set[String] = keyFromNode(node) match + case Some(key) => put(key, typeFullNames) + case None => Set.empty - def append(node: AstNode, typeFullName: String): Set[String] = - append(node, Set(typeFullName)) + def append(node: AstNode, typeFullName: String): Set[String] = + append(node, Set(typeFullName)) - def append(node: K, typeFullName: String): Set[String] = - append(node, Set(typeFullName)) + def append(node: K, typeFullName: String): Set[String] = + append(node, Set(typeFullName)) - def append(node: AstNode, typeFullNames: Set[String]): Set[String] = keyFromNode(node) match - case Some(key) => append(key, typeFullNames) - case None => Set.empty + def append(node: AstNode, typeFullNames: Set[String]): Set[String] = keyFromNode(node) match + case Some(key) => append(key, typeFullNames) + case None => Set.empty - def append(sbKey: K, typeFullNames: Set[String]): Set[String] = - table.get(sbKey) match - case Some(ts) if ts == typeFullNames => ts - case Some(ts) if typeFullNames.nonEmpty => put(sbKey, coalesce(ts, typeFullNames)) - case None if typeFullNames.nonEmpty => put(sbKey, coalesce(Set.empty, typeFullNames)) - case _ => Set.empty + def append(sbKey: K, typeFullNames: Set[String]): Set[String] = + table.get(sbKey) match + case Some(ts) if ts == typeFullNames => ts + case Some(ts) if typeFullNames.nonEmpty => put(sbKey, coalesce(ts, typeFullNames)) + case None if typeFullNames.nonEmpty => put(sbKey, coalesce(Set.empty, typeFullNames)) + case _ => Set.empty - def contains(sbKey: K): Boolean = table.contains(sbKey) + def contains(sbKey: K): Boolean = table.contains(sbKey) - def contains(node: AstNode): Boolean = keyFromNode(node) match - case Some(key) => contains(key) - case None => false + def contains(node: AstNode): Boolean = keyFromNode(node) match + case Some(key) => contains(key) + case None => false - def get(sbKey: K): Set[String] = table.getOrElse(sbKey, Set.empty) + def get(sbKey: K): Set[String] = table.getOrElse(sbKey, Set.empty) - def get(node: AstNode): Set[String] = keyFromNode(node) match - case Some(key) => get(key) - case None => Set.empty + def get(node: AstNode): Set[String] = keyFromNode(node) match + case Some(key) => get(key) + case None => Set.empty - def remove(sbKey: K): Set[String] = table.remove(sbKey).getOrElse(Set.empty) + def remove(sbKey: K): Set[String] = table.remove(sbKey).getOrElse(Set.empty) - def remove(node: AstNode): Set[String] = keyFromNode(node) match - case Some(key) => remove(key) - case None => Set.empty + def remove(node: AstNode): Set[String] = keyFromNode(node) match + case Some(key) => remove(key) + case None => Set.empty - def view: MapView[K, Set[String]] = table.view + def view: MapView[K, Set[String]] = table.view - def clear(): Unit = table.clear() + def clear(): Unit = table.clear() end SymbolTable diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/TypeNodePass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/TypeNodePass.scala index 4e060c2c..f5c3b9c9 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/TypeNodePass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/TypeNodePass.scala @@ -24,71 +24,71 @@ class TypeNodePass private ( getTypesFromCpg: Boolean ) extends CpgPass(cpg, "types", keyPool): - private def getTypeDeclTypes(): mutable.Set[String] = - val typeDeclTypes = mutable.Set[String]() - cpg.typeDecl.foreach { typeDecl => - typeDeclTypes += typeDecl.fullName - typeDeclTypes ++= typeDecl.inheritsFromTypeFullName - } - typeDeclTypes + private def getTypeDeclTypes(): mutable.Set[String] = + val typeDeclTypes = mutable.Set[String]() + cpg.typeDecl.foreach { typeDecl => + typeDeclTypes += typeDecl.fullName + typeDeclTypes ++= typeDecl.inheritsFromTypeFullName + } + typeDeclTypes - def getTypeFullNamesFromCpg(): Set[String] = - cpg.all - .map(_.property(PropertyNames.TYPE_FULL_NAME)) - .filter(_ != null) - .map(_.toString) - .toSet + def getTypeFullNamesFromCpg(): Set[String] = + cpg.all + .map(_.property(PropertyNames.TYPE_FULL_NAME)) + .filter(_ != null) + .map(_.toString) + .toSet - override def run(diffGraph: DiffGraphBuilder): Unit = - val typeFullNameValues = - if getTypesFromCpg then - getTypeFullNamesFromCpg() - else - registeredTypes.toSet + override def run(diffGraph: DiffGraphBuilder): Unit = + val typeFullNameValues = + if getTypesFromCpg then + getTypeFullNamesFromCpg() + else + registeredTypes.toSet - val usedTypesSet = getTypeDeclTypes() ++ typeFullNameValues - usedTypesSet.remove("") - val usedTypes = usedTypesSet.filterInPlace( - !_.endsWith(NamespaceTraversal.globalNamespaceName) - ).toArray.sorted + val usedTypesSet = getTypeDeclTypes() ++ typeFullNameValues + usedTypesSet.remove("") + val usedTypes = usedTypesSet.filterInPlace( + !_.endsWith(NamespaceTraversal.globalNamespaceName) + ).toArray.sorted - diffGraph.addNode( - NewType() - .name("ANY") - .fullName("ANY") - .typeDeclFullName("ANY") - ) + diffGraph.addNode( + NewType() + .name("ANY") + .fullName("ANY") + .typeDeclFullName("ANY") + ) - usedTypes.foreach { typeName => - val shortName = fullToShortName(typeName) - val node = NewType() - .name(shortName) - .fullName(typeName) - .typeDeclFullName(typeName) - diffGraph.addNode(node) - } - end run + usedTypes.foreach { typeName => + val shortName = fullToShortName(typeName) + val node = NewType() + .name(shortName) + .fullName(typeName) + .typeDeclFullName(typeName) + diffGraph.addNode(node) + } + end run end TypeNodePass object TypeNodePass: - // Lambda typeDecl type names fit the structure - // `a.b.c.d.ClassName.lambda$method$name:returnType(paramTypes)` - // so this regex works by greedily matching the package and class names - // at the start and cutting off the matched group before the signature. - private val lambdaTypeRegex = raw".*\.(.*):.*\(.*\)".r + // Lambda typeDecl type names fit the structure + // `a.b.c.d.ClassName.lambda$method$name:returnType(paramTypes)` + // so this regex works by greedily matching the package and class names + // at the start and cutting off the matched group before the signature. + private val lambdaTypeRegex = raw".*\.(.*):.*\(.*\)".r - def withTypesFromCpg(cpg: Cpg, keyPool: Option[KeyPool] = None): TypeNodePass = - new TypeNodePass(Nil, cpg, keyPool, getTypesFromCpg = true) + def withTypesFromCpg(cpg: Cpg, keyPool: Option[KeyPool] = None): TypeNodePass = + new TypeNodePass(Nil, cpg, keyPool, getTypesFromCpg = true) - def withRegisteredTypes( - registeredTypes: List[String], - cpg: Cpg, - keyPool: Option[KeyPool] = None - ): TypeNodePass = - new TypeNodePass(registeredTypes, cpg, keyPool, getTypesFromCpg = false) + def withRegisteredTypes( + registeredTypes: List[String], + cpg: Cpg, + keyPool: Option[KeyPool] = None + ): TypeNodePass = + new TypeNodePass(registeredTypes, cpg, keyPool, getTypesFromCpg = false) - def fullToShortName(typeName: String): String = - typeName match - case lambdaTypeRegex(methodName) => methodName - case _ => typeName.split('.').lastOption.getOrElse(typeName) + def fullToShortName(typeName: String): String = + typeName match + case lambdaTypeRegex(methodName) => methodName + case _ => typeName.split('.').lastOption.getOrElse(typeName) end TypeNodePass diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XConfigFileCreationPass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XConfigFileCreationPass.scala index 6568f7c6..eee14d66 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XConfigFileCreationPass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XConfigFileCreationPass.scala @@ -16,55 +16,55 @@ import scala.util.{Failure, Success, Try} */ abstract class XConfigFileCreationPass(cpg: Cpg) extends ConcurrentWriterCpgPass[File](cpg): - private val logger = LoggerFactory.getLogger(this.getClass) + private val logger = LoggerFactory.getLogger(this.getClass) - // File filters to override by the implementing class - protected val configFileFilters: List[File => Boolean] + // File filters to override by the implementing class + protected val configFileFilters: List[File => Boolean] - private val rootDir = cpg.metaData.root.headOption.getOrElse("") + private val rootDir = cpg.metaData.root.headOption.getOrElse("") - override def generateParts(): Array[File] = - if rootDir.isBlank then - logger.debug("Unable to recover project directory for configuration file pass.") - Array.empty - else - Try(File(rootDir)) match - case Success(file) if file.isDirectory => - file.listRecursively - .filter(isConfigFile) - .toArray + override def generateParts(): Array[File] = + if rootDir.isBlank then + logger.debug("Unable to recover project directory for configuration file pass.") + Array.empty + else + Try(File(rootDir)) match + case Success(file) if file.isDirectory => + file.listRecursively + .filter(isConfigFile) + .toArray - case Success(file) if isConfigFile(file) => - Array(file) + case Success(file) if isConfigFile(file) => + Array(file) - case _ => Array.empty + case _ => Array.empty - override def runOnPart(diffGraph: DiffGraphBuilder, file: File): Unit = - Try(IOUtils.readEntireFile(file.path)) match - case Success(content) => - val name = configFileName(file) - val configNode = NewConfigFile().name(name).content(content) - logger.debug(s"Adding config file $name") - diffGraph.addNode(configNode) + override def runOnPart(diffGraph: DiffGraphBuilder, file: File): Unit = + Try(IOUtils.readEntireFile(file.path)) match + case Success(content) => + val name = configFileName(file) + val configNode = NewConfigFile().name(name).content(content) + logger.debug(s"Adding config file $name") + diffGraph.addNode(configNode) - case Failure(error) => - logger.debug(s"Unable to create config file node for ${file.canonicalPath}: $error") + case Failure(error) => + logger.debug(s"Unable to create config file node for ${file.canonicalPath}: $error") - private def configFileName(configFile: File): String = - Try(Paths.get(rootDir).toAbsolutePath) - .map(_.relativize(configFile.path.toAbsolutePath).toString) - .orElse(Try(configFile.pathAsString)) - .getOrElse(configFile.name) + private def configFileName(configFile: File): String = + Try(Paths.get(rootDir).toAbsolutePath) + .map(_.relativize(configFile.path.toAbsolutePath).toString) + .orElse(Try(configFile.pathAsString)) + .getOrElse(configFile.name) - protected def extensionFilter(extension: String)(file: File): Boolean = - file.extension.contains(extension) + protected def extensionFilter(extension: String)(file: File): Boolean = + file.extension.contains(extension) - protected def pathEndFilter(pathEnd: String)(file: File): Boolean = - file.canonicalPath.endsWith(pathEnd) + protected def pathEndFilter(pathEnd: String)(file: File): Boolean = + file.canonicalPath.endsWith(pathEnd) - protected def pathRegexFilter(pathRegex: String)(file: File): Boolean = - file.canonicalPath.matches(pathRegex) + protected def pathRegexFilter(pathRegex: String)(file: File): Boolean = + file.canonicalPath.matches(pathRegex) - private def isConfigFile(file: File): Boolean = - configFileFilters.exists(predicate => predicate(file)) + private def isConfigFile(file: File): Boolean = + configFileFilters.exists(predicate => predicate(file)) end XConfigFileCreationPass diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XImportResolverPass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XImportResolverPass.scala index d803ca86..a1615e55 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XImportResolverPass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XImportResolverPass.scala @@ -12,119 +12,119 @@ import java.io.File as JFile abstract class XImportResolverPass(cpg: Cpg) extends ConcurrentWriterCpgPass[Import](cpg): - protected val logger: Logger = LoggerFactory.getLogger(this.getClass) - protected val codeRoot: String = cpg.metaData.root.headOption.getOrElse(JFile.separator) - - override def generateParts(): Array[Import] = cpg.imports.toArray - - override def runOnPart(builder: DiffGraphBuilder, part: Import): Unit = for - call <- part.call - fileName = call.file.name.headOption.getOrElse("").stripPrefix(codeRoot) - importedAs <- part.importedAs - importedEntity <- part.importedEntity - do - optionalResolveImport(fileName, call, importedEntity, importedAs, builder) - - protected def optionalResolveImport( - fileName: String, - importCall: Call, - importedEntity: String, - importedAs: String, - diffGraph: DiffGraphBuilder - ): Unit - - protected def resolvedImportToTag( - x: ResolvedImport, - importCall: Call, - diffGraph: DiffGraphBuilder - ): Unit = - importCall.start.newTagNodePair(x.label, x.serialize).store()(diffGraph) + protected val logger: Logger = LoggerFactory.getLogger(this.getClass) + protected val codeRoot: String = cpg.metaData.root.headOption.getOrElse(JFile.separator) + + override def generateParts(): Array[Import] = cpg.imports.toArray + + override def runOnPart(builder: DiffGraphBuilder, part: Import): Unit = for + call <- part.call + fileName = call.file.name.headOption.getOrElse("").stripPrefix(codeRoot) + importedAs <- part.importedAs + importedEntity <- part.importedEntity + do + optionalResolveImport(fileName, call, importedEntity, importedAs, builder) + + protected def optionalResolveImport( + fileName: String, + importCall: Call, + importedEntity: String, + importedAs: String, + diffGraph: DiffGraphBuilder + ): Unit + + protected def resolvedImportToTag( + x: ResolvedImport, + importCall: Call, + diffGraph: DiffGraphBuilder + ): Unit = + importCall.start.newTagNodePair(x.label, x.serialize).store()(diffGraph) end XImportResolverPass object ImportsPass: - sealed trait ResolvedImport: - def label: String - - def serialize: String - - implicit class TagToResolvedImportExt(traversal: Iterator[Tag]): - def toResolvedImport: Iterator[ResolvedImport] = - traversal.flatMap(ResolvedImport.tagToResolvedImport) - - object ResolvedImport: - - val RESOLVED_METHOD = "RESOLVED_METHOD" - val RESOLVED_TYPE_DECL = "RESOLVED_TYPE_DECL" - val RESOLVED_MEMBER = "RESOLVED_MEMBER" - val UNKNOWN_METHOD = "UNKNOWN_METHOD" - val UNKNOWN_TYPE_DECL = "UNKNOWN_TYPE_DECL" - val UNKNOWN_IMPORT = "UNKNOWN_IMPORT" - - val OPT_FULL_NAME = "FULL_NAME" - val OPT_ALIAS = "ALIAS" - val OPT_RECEIVER = "RECEIVER" - val OPT_BASE_PATH = "BASE_PATH" - val OPT_NAME = "NAME" - - def tagToResolvedImport(tag: Tag): Option[ResolvedImport] = Option(tag.name match - case RESOLVED_METHOD => - val opts = valueToOptions(tag.value) - ResolvedMethod(opts(OPT_FULL_NAME), opts(OPT_ALIAS), opts.get(OPT_RECEIVER)) - case RESOLVED_TYPE_DECL => ResolvedTypeDecl(tag.value) - case RESOLVED_MEMBER => - val opts = valueToOptions(tag.value) - ResolvedMember(opts(OPT_BASE_PATH), opts(OPT_NAME)) - case UNKNOWN_METHOD => - val opts = valueToOptions(tag.value) - UnknownMethod(opts(OPT_FULL_NAME), opts(OPT_ALIAS), opts.get(OPT_RECEIVER)) - case UNKNOWN_TYPE_DECL => UnknownTypeDecl(tag.value) - case UNKNOWN_IMPORT => UnknownImport(tag.value) - case _ => null - ) - - private def valueToOptions(x: String): Map[String, String] = - x.split(',').grouped(2).map(xs => xs(0) -> xs(1)).toMap - end ResolvedImport - - case class ResolvedMethod( - fullName: String, - alias: String, - receiver: Option[String] = None, - override val label: String = RESOLVED_METHOD - ) extends ResolvedImport: - override def serialize: String = - s"$OPT_FULL_NAME,$fullName,$OPT_ALIAS,$alias" + receiver.map(r => - s",$OPT_RECEIVER,$r" - ).getOrElse("") - - case class ResolvedTypeDecl(fullName: String, override val label: String = RESOLVED_TYPE_DECL) - extends ResolvedImport: - override def serialize: String = fullName - - case class ResolvedMember( - basePath: String, - memberName: String, - override val label: String = RESOLVED_MEMBER - ) extends ResolvedImport: - override def serialize: String = s"$OPT_BASE_PATH,$basePath,$OPT_NAME,$memberName" - - case class UnknownMethod( - fullName: String, - alias: String, - receiver: Option[String] = None, - override val label: String = UNKNOWN_METHOD - ) extends ResolvedImport: - override def serialize: String = - s"$OPT_FULL_NAME,$fullName,$OPT_ALIAS,$alias" + receiver.map(r => - s",$OPT_RECEIVER,$r" - ).getOrElse("") - - case class UnknownTypeDecl(fullName: String, override val label: String = UNKNOWN_TYPE_DECL) - extends ResolvedImport: - override def serialize: String = fullName - - case class UnknownImport(path: String, override val label: String = UNKNOWN_IMPORT) - extends ResolvedImport: - override def serialize: String = path + sealed trait ResolvedImport: + def label: String + + def serialize: String + + implicit class TagToResolvedImportExt(traversal: Iterator[Tag]): + def toResolvedImport: Iterator[ResolvedImport] = + traversal.flatMap(ResolvedImport.tagToResolvedImport) + + object ResolvedImport: + + val RESOLVED_METHOD = "RESOLVED_METHOD" + val RESOLVED_TYPE_DECL = "RESOLVED_TYPE_DECL" + val RESOLVED_MEMBER = "RESOLVED_MEMBER" + val UNKNOWN_METHOD = "UNKNOWN_METHOD" + val UNKNOWN_TYPE_DECL = "UNKNOWN_TYPE_DECL" + val UNKNOWN_IMPORT = "UNKNOWN_IMPORT" + + val OPT_FULL_NAME = "FULL_NAME" + val OPT_ALIAS = "ALIAS" + val OPT_RECEIVER = "RECEIVER" + val OPT_BASE_PATH = "BASE_PATH" + val OPT_NAME = "NAME" + + def tagToResolvedImport(tag: Tag): Option[ResolvedImport] = Option(tag.name match + case RESOLVED_METHOD => + val opts = valueToOptions(tag.value) + ResolvedMethod(opts(OPT_FULL_NAME), opts(OPT_ALIAS), opts.get(OPT_RECEIVER)) + case RESOLVED_TYPE_DECL => ResolvedTypeDecl(tag.value) + case RESOLVED_MEMBER => + val opts = valueToOptions(tag.value) + ResolvedMember(opts(OPT_BASE_PATH), opts(OPT_NAME)) + case UNKNOWN_METHOD => + val opts = valueToOptions(tag.value) + UnknownMethod(opts(OPT_FULL_NAME), opts(OPT_ALIAS), opts.get(OPT_RECEIVER)) + case UNKNOWN_TYPE_DECL => UnknownTypeDecl(tag.value) + case UNKNOWN_IMPORT => UnknownImport(tag.value) + case _ => null + ) + + private def valueToOptions(x: String): Map[String, String] = + x.split(',').grouped(2).map(xs => xs(0) -> xs(1)).toMap + end ResolvedImport + + case class ResolvedMethod( + fullName: String, + alias: String, + receiver: Option[String] = None, + override val label: String = RESOLVED_METHOD + ) extends ResolvedImport: + override def serialize: String = + s"$OPT_FULL_NAME,$fullName,$OPT_ALIAS,$alias" + receiver.map(r => + s",$OPT_RECEIVER,$r" + ).getOrElse("") + + case class ResolvedTypeDecl(fullName: String, override val label: String = RESOLVED_TYPE_DECL) + extends ResolvedImport: + override def serialize: String = fullName + + case class ResolvedMember( + basePath: String, + memberName: String, + override val label: String = RESOLVED_MEMBER + ) extends ResolvedImport: + override def serialize: String = s"$OPT_BASE_PATH,$basePath,$OPT_NAME,$memberName" + + case class UnknownMethod( + fullName: String, + alias: String, + receiver: Option[String] = None, + override val label: String = UNKNOWN_METHOD + ) extends ResolvedImport: + override def serialize: String = + s"$OPT_FULL_NAME,$fullName,$OPT_ALIAS,$alias" + receiver.map(r => + s",$OPT_RECEIVER,$r" + ).getOrElse("") + + case class UnknownTypeDecl(fullName: String, override val label: String = UNKNOWN_TYPE_DECL) + extends ResolvedImport: + override def serialize: String = fullName + + case class UnknownImport(path: String, override val label: String = UNKNOWN_IMPORT) + extends ResolvedImport: + override def serialize: String = path end ImportsPass diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XImportsPass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XImportsPass.scala index fdf7fe14..ae9a5a56 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XImportsPass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XImportsPass.scala @@ -9,19 +9,19 @@ import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.Assignment abstract class XImportsPass(cpg: Cpg) extends ConcurrentWriterCpgPass[(Call, Assignment)](cpg): - protected val importCallName: String + protected val importCallName: String - override def generateParts(): Array[(Call, Assignment)] = cpg - .call(importCallName) - .flatMap(importCallToPart) - .toArray + override def generateParts(): Array[(Call, Assignment)] = cpg + .call(importCallName) + .flatMap(importCallToPart) + .toArray - protected def importCallToPart(x: Call): Iterator[(Call, Assignment)] + protected def importCallToPart(x: Call): Iterator[(Call, Assignment)] - override def runOnPart(diffGraph: DiffGraphBuilder, part: (Call, Assignment)): Unit = - val (call, assignment) = part - val importedEntity = importedEntityFromCall(call) - val importedAs = assignment.target.code - createImportNodeAndLink(importedEntity, importedAs, Some(call), diffGraph) + override def runOnPart(diffGraph: DiffGraphBuilder, part: (Call, Assignment)): Unit = + val (call, assignment) = part + val importedEntity = importedEntityFromCall(call) + val importedAs = assignment.target.code + createImportNodeAndLink(importedEntity, importedAs, Some(call), diffGraph) - protected def importedEntityFromCall(call: Call): String + protected def importedEntityFromCall(call: Call): String diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XInheritanceFullNamePass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XInheritanceFullNamePass.scala index 79e338e2..8a1a3dc1 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XInheritanceFullNamePass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XInheritanceFullNamePass.scala @@ -16,137 +16,137 @@ import java.util.regex.{Matcher, Pattern} */ abstract class XInheritanceFullNamePass(cpg: Cpg) extends ForkJoinParallelCpgPass[TypeDecl](cpg): - protected val pathSep: Char = '.' - protected val fileModuleSep: Char = ':' - protected val moduleName: String - protected val fileExt: String - - private val relativePathPattern = Pattern.compile("^[.]+/?.*") - - override def generateParts(): Array[TypeDecl] = - cpg.typeDecl - .filterNot(t => inheritsNothingOfInterest(t.inheritsFromTypeFullName)) - .toArray - - override def runOnPart(builder: DiffGraphBuilder, source: TypeDecl): Unit = - val resolvedTypeDecls = resolveInheritedTypeFullName(source, builder) - if resolvedTypeDecls.nonEmpty then - val fullNames = resolvedTypeDecls.map(_.fullName) - builder.setNodeProperty(source, PropertyNames.INHERITS_FROM_TYPE_FULL_NAME, fullNames) - cpg.typ.fullNameExact(fullNames*).foreach(tgt => - builder.addEdge(source, tgt, EdgeTypes.INHERITS_FROM) - ) - - protected def inheritsNothingOfInterest(inheritedTypes: Seq[String]): Boolean = - inheritedTypes == Seq("ANY") || inheritedTypes == Seq("object") || inheritedTypes.isEmpty - - private def extractTypeDeclFromNode(node: AstNode): Option[String] = node match - case x: Call if x.isCallForImportOut.nonEmpty => - x.isCallForImportOut.importedEntity.map { - case imp if relativePathPattern.matcher(imp).matches() => - imp.split(pathSep).toList match - case head :: next => - (Paths.get(head).normalize().toString +: next).mkString( - pathSep.toString - ) - case Nil => - Paths.get(imp).normalize().toString - case imp => imp - }.headOption - case x: TypeDecl => Option(x.fullName) - case _ => None - - protected def resolveInheritedTypeFullName( - td: TypeDecl, - builder: DiffGraphBuilder - ): Seq[TypeDeclBase] = - val callsOfInterest = td.file.method.flatMap(_._callViaContainsOut) - val typeDeclsOfInterest = td.file.typeDecl - val qualifiedNamesInScope = (callsOfInterest ++ typeDeclsOfInterest) - .flatMap(extractTypeDeclFromNode) - .filterNot(_.endsWith(moduleName)) - .l - val matchersInScope = qualifiedNamesInScope.map { - case x if x.contains(pathSep) => - val splitName = x.split(pathSep) - splitName.last - case x => x - }.distinct - val validTypeDecls = cpg.typeDecl.nameExact(matchersInScope*).toList - val filteredTypes = - validTypeDecls.filter(vt => td.inheritsFromTypeFullName.contains(vt.name)) - if filteredTypes.isEmpty then - // Usually in the case of inheriting external types - qualifiedNamesInScope - .flatMap { qn => - td.inheritsFromTypeFullName.find(t => - ImportStringHandling.namesIntersect(qn, t, pathSep) - ) match - case Some(path) => Option((qn, path)) - case None => None - } - .map { case (qualifiedName, inheritedNames) => - xTypeFullName(qualifiedName, inheritedNames) - } - .map { case (name, fullName) => createTypeStub(name, fullName, builder) } - else - filteredTypes - end resolveInheritedTypeFullName - - private def createTypeStub( - name: String, - fullName: String, - builder: DiffGraphBuilder - ): TypeDeclBase = - cpg.typeDecl.fullNameExact(fullName).headOption match - case Some(typeDecl) => typeDecl - case None => - val typeDecl = TypeDeclStubCreator.createTypeDeclStub(name, fullName) - builder.addNode(typeDecl) - typeDecl - - /** Converts types in the form `foo.bar.Baz` to `foo/bar.js::program:Baz` and will result in a - * tuple of name and full name. - */ - protected def xTypeFullName(importedType: String, importedPath: String): (String, String) = - val combinedPath = ImportStringHandling.combinedPath(importedType, importedPath, pathSep) - combinedPath.split(pathSep).lastOption match - case Some(tName) => - ( - tName, - combinedPath - .stripSuffix(s"$pathSep$tName") - .replaceAll( - s"${Pattern.quote(pathSep.toString)}", - Matcher.quoteReplacement(File.separator) - ) + - Seq(s"$fileExt$fileModuleSep$moduleName", tName).mkString(pathSep.toString) - ) - case None => (combinedPath, combinedPath) + protected val pathSep: Char = '.' + protected val fileModuleSep: Char = ':' + protected val moduleName: String + protected val fileExt: String + + private val relativePathPattern = Pattern.compile("^[.]+/?.*") + + override def generateParts(): Array[TypeDecl] = + cpg.typeDecl + .filterNot(t => inheritsNothingOfInterest(t.inheritsFromTypeFullName)) + .toArray + + override def runOnPart(builder: DiffGraphBuilder, source: TypeDecl): Unit = + val resolvedTypeDecls = resolveInheritedTypeFullName(source, builder) + if resolvedTypeDecls.nonEmpty then + val fullNames = resolvedTypeDecls.map(_.fullName) + builder.setNodeProperty(source, PropertyNames.INHERITS_FROM_TYPE_FULL_NAME, fullNames) + cpg.typ.fullNameExact(fullNames*).foreach(tgt => + builder.addEdge(source, tgt, EdgeTypes.INHERITS_FROM) + ) + + protected def inheritsNothingOfInterest(inheritedTypes: Seq[String]): Boolean = + inheritedTypes == Seq("ANY") || inheritedTypes == Seq("object") || inheritedTypes.isEmpty + + private def extractTypeDeclFromNode(node: AstNode): Option[String] = node match + case x: Call if x.isCallForImportOut.nonEmpty => + x.isCallForImportOut.importedEntity.map { + case imp if relativePathPattern.matcher(imp).matches() => + imp.split(pathSep).toList match + case head :: next => + (Paths.get(head).normalize().toString +: next).mkString( + pathSep.toString + ) + case Nil => + Paths.get(imp).normalize().toString + case imp => imp + }.headOption + case x: TypeDecl => Option(x.fullName) + case _ => None + + protected def resolveInheritedTypeFullName( + td: TypeDecl, + builder: DiffGraphBuilder + ): Seq[TypeDeclBase] = + val callsOfInterest = td.file.method.flatMap(_._callViaContainsOut) + val typeDeclsOfInterest = td.file.typeDecl + val qualifiedNamesInScope = (callsOfInterest ++ typeDeclsOfInterest) + .flatMap(extractTypeDeclFromNode) + .filterNot(_.endsWith(moduleName)) + .l + val matchersInScope = qualifiedNamesInScope.map { + case x if x.contains(pathSep) => + val splitName = x.split(pathSep) + splitName.last + case x => x + }.distinct + val validTypeDecls = cpg.typeDecl.nameExact(matchersInScope*).toList + val filteredTypes = + validTypeDecls.filter(vt => td.inheritsFromTypeFullName.contains(vt.name)) + if filteredTypes.isEmpty then + // Usually in the case of inheriting external types + qualifiedNamesInScope + .flatMap { qn => + td.inheritsFromTypeFullName.find(t => + ImportStringHandling.namesIntersect(qn, t, pathSep) + ) match + case Some(path) => Option((qn, path)) + case None => None + } + .map { case (qualifiedName, inheritedNames) => + xTypeFullName(qualifiedName, inheritedNames) + } + .map { case (name, fullName) => createTypeStub(name, fullName, builder) } + else + filteredTypes + end resolveInheritedTypeFullName + + private def createTypeStub( + name: String, + fullName: String, + builder: DiffGraphBuilder + ): TypeDeclBase = + cpg.typeDecl.fullNameExact(fullName).headOption match + case Some(typeDecl) => typeDecl + case None => + val typeDecl = TypeDeclStubCreator.createTypeDeclStub(name, fullName) + builder.addNode(typeDecl) + typeDecl + + /** Converts types in the form `foo.bar.Baz` to `foo/bar.js::program:Baz` and will result in a + * tuple of name and full name. + */ + protected def xTypeFullName(importedType: String, importedPath: String): (String, String) = + val combinedPath = ImportStringHandling.combinedPath(importedType, importedPath, pathSep) + combinedPath.split(pathSep).lastOption match + case Some(tName) => + ( + tName, + combinedPath + .stripSuffix(s"$pathSep$tName") + .replaceAll( + s"${Pattern.quote(pathSep.toString)}", + Matcher.quoteReplacement(File.separator) + ) + + Seq(s"$fileExt$fileModuleSep$moduleName", tName).mkString(pathSep.toString) + ) + case None => (combinedPath, combinedPath) end XInheritanceFullNamePass object ImportStringHandling: - def namesIntersect(a: String, b: String, pathSep: Char = '.'): Boolean = - val (as, bs, intersect) = splitAndIntersect(a, b, pathSep) - intersect.nonEmpty && (as.endsWith(intersect) || bs.endsWith(intersect)) - - private def splitAndIntersect( - a: String, - b: String, - pathSep: Char = '.' - ): (Seq[String], Seq[String], Seq[String]) = - val as = a.split(pathSep).toIndexedSeq - val bs = b.split(pathSep).toIndexedSeq - (as, bs, as.intersect(bs)) - - def combinedPath(importedType: String, importedPath: String, pathSep: Char = '.'): String = - val (a, b) = - if importedType.length > importedPath.length then - (importedType, importedPath) - else - (importedPath, importedType) - val (as, bs, intersect) = splitAndIntersect(a, b, pathSep) - - if a == importedPath then bs.diff(intersect).concat(as).mkString(pathSep.toString) - else as.diff(intersect).concat(bs).mkString(pathSep.toString) + def namesIntersect(a: String, b: String, pathSep: Char = '.'): Boolean = + val (as, bs, intersect) = splitAndIntersect(a, b, pathSep) + intersect.nonEmpty && (as.endsWith(intersect) || bs.endsWith(intersect)) + + private def splitAndIntersect( + a: String, + b: String, + pathSep: Char = '.' + ): (Seq[String], Seq[String], Seq[String]) = + val as = a.split(pathSep).toIndexedSeq + val bs = b.split(pathSep).toIndexedSeq + (as, bs, as.intersect(bs)) + + def combinedPath(importedType: String, importedPath: String, pathSep: Char = '.'): String = + val (a, b) = + if importedType.length > importedPath.length then + (importedType, importedPath) + else + (importedPath, importedType) + val (as, bs, intersect) = splitAndIntersect(a, b, pathSep) + + if a == importedPath then bs.diff(intersect).concat(as).mkString(pathSep.toString) + else as.diff(intersect).concat(bs).mkString(pathSep.toString) end ImportStringHandling diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XTypeHintCallLinker.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XTypeHintCallLinker.scala index ea7fd4fc..b14d41ab 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XTypeHintCallLinker.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XTypeHintCallLinker.scala @@ -23,176 +23,176 @@ import scala.collection.mutable */ abstract class XTypeHintCallLinker(cpg: Cpg) extends CpgPass(cpg): - implicit protected val resolver: NoResolve.type = NoResolve - private val fileNamePattern = Pattern.compile("^(.*(.py|.js|.rb)).*$") - protected val pathSep: Char = '.' - - protected def calls: Iterator[Call] = cpg.call - .nameNot(".*", ".*") - .filter(c => calleeNames(c).nonEmpty && c.callee.isEmpty) - - protected def calleeNames(c: Call): Seq[String] = - c.dynamicTypeHintFullName.filterNot(_.equals("ANY")).distinct - - protected def callees(names: Seq[String]): List[Method] = - cpg.method.fullNameExact(names*).toList - - override def run(builder: DiffGraphBuilder): Unit = linkCalls(builder) - - protected def linkCalls(builder: DiffGraphBuilder): Unit = - val callerAndCallees = calls.map(call => (call, calleeNames(call))).toList - val methodMap = mapMethods(callerAndCallees, builder) - linkCallsToCallees(callerAndCallees, methodMap, builder) - linkSpeculativeNamespaceNodes(methodMap.view.values.collectAll[NewMethod], builder) - - protected def mapMethods( - callerAndCallees: List[(Call, Seq[String])], - builder: DiffGraphBuilder - ): Map[String, MethodBase] = - val methodMap = mutable.HashMap.empty[String, MethodBase] - val newMethods = mutable.HashMap.empty[String, NewMethod] - callerAndCallees.foreach { case (call, methodNames) => - val ms = callees(methodNames) - if ms.nonEmpty then - ms.foreach { m => methodMap.put(m.fullName, m) } - else - val mNames = ms.map(_.fullName).toSet - methodNames - .filterNot(mNames.contains) - .map(fullName => - newMethods.getOrElseUpdate( - fullName, - createMethodStub(fullName, call, builder) - ) - ) - .foreach { m => methodMap.put(m.fullName, m) } - } - methodMap.toMap - end mapMethods - - protected def linkCallsToCallees( - callerAndCallees: List[(Call, Seq[String])], - methodMap: Map[String, MethodBase], - builder: DiffGraphBuilder - ): Unit = - // Link edges to method nodes - callerAndCallees.foreach { case (call, methodNames) => - methodNames - .flatMap(methodMap.get) - .filter(m => call.callee(NoResolve).fullNameExact(m.fullName).isEmpty) - .foreach { m => linkCallToCallee(call, m, builder) } - setCallees(call, methodNames, builder) - } - - def linkCallToCallee(call: Call, method: MethodBase, builder: DiffGraphBuilder): Unit = - builder.addEdge(call, method, EdgeTypes.CALL) - method match - case method: Method => - builder.setNodeProperty( - call, - PropertyNames.TYPE_FULL_NAME, - method.methodReturn.typeFullName - ) - case _ => - - protected def setCallees( - call: Call, - methodNames: Seq[String], - builder: DiffGraphBuilder - ): Unit = - val nonDummyTypes = methodNames.filterNot(isDummyType) - if methodNames.sizeIs == 1 then - builder.setNodeProperty(call, PropertyNames.METHOD_FULL_NAME, methodNames.head) - builder.setNodeProperty( - call, - PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, - call.dynamicTypeHintFullName.diff(methodNames) - ) - else if methodNames.sizeIs > 1 && methodNames != nonDummyTypes then - setCallees(call, nonDummyTypes, builder) - - protected def createMethodStub( - methodName: String, - call: Call, - builder: DiffGraphBuilder - ): NewMethod = - // In the case of Python/JS we can use name info to check if, despite the method name might be incorrect, that we - // label the method correctly as internal by finding that the method should belong to an internal file - val matcher = fileNamePattern.matcher(methodName) - val basePath = cpg.metaData.root.head - val isExternal = if matcher.matches() then - val fileName = matcher.group(1) - cpg.file.nameExact(s"$basePath$fileName").isEmpty + implicit protected val resolver: NoResolve.type = NoResolve + private val fileNamePattern = Pattern.compile("^(.*(.py|.js|.rb)).*$") + protected val pathSep: Char = '.' + + protected def calls: Iterator[Call] = cpg.call + .nameNot(".*", ".*") + .filter(c => calleeNames(c).nonEmpty && c.callee.isEmpty) + + protected def calleeNames(c: Call): Seq[String] = + c.dynamicTypeHintFullName.filterNot(_.equals("ANY")).distinct + + protected def callees(names: Seq[String]): List[Method] = + cpg.method.fullNameExact(names*).toList + + override def run(builder: DiffGraphBuilder): Unit = linkCalls(builder) + + protected def linkCalls(builder: DiffGraphBuilder): Unit = + val callerAndCallees = calls.map(call => (call, calleeNames(call))).toList + val methodMap = mapMethods(callerAndCallees, builder) + linkCallsToCallees(callerAndCallees, methodMap, builder) + linkSpeculativeNamespaceNodes(methodMap.view.values.collectAll[NewMethod], builder) + + protected def mapMethods( + callerAndCallees: List[(Call, Seq[String])], + builder: DiffGraphBuilder + ): Map[String, MethodBase] = + val methodMap = mutable.HashMap.empty[String, MethodBase] + val newMethods = mutable.HashMap.empty[String, NewMethod] + callerAndCallees.foreach { case (call, methodNames) => + val ms = callees(methodNames) + if ms.nonEmpty then + ms.foreach { m => methodMap.put(m.fullName, m) } else - true - val name = - if methodName.contains(pathSep) && methodName.length > methodName.lastIndexOf( - pathSep - ) + 1 - then - methodName.substring(methodName.lastIndexOf(pathSep) + 1) - else methodName - createMethodStub(name, methodName, call.argumentOut.size, isExternal, builder) - end createMethodStub - - /** Try to extract a type full name from the method full name, if one exists in the CPG then we - * are lucky and we use it, else we ignore for now. - */ - protected def createMethodStub( - name: String, - fullName: String, - argSize: Int, - isExternal: Boolean, - builder: DiffGraphBuilder - ): NewMethod = - val nameIdx = fullName.lastIndexOf(name) - val default = (NodeTypes.NAMESPACE_BLOCK, XTypeHintCallLinker.namespace) - val (astParentType, astParentFullName) = - if !fullName.isBlank && !fullName.startsWith(" 0 then - cpg.typeDecl - .fullNameExact(fullName.substring(0, nameIdx - 1)) - .map(t => t.label -> t.fullName) - .headOption - .getOrElse(default) - else - default - - MethodStubCreator - .createMethodStub( - name, - fullName, - "", - DispatchTypes.DYNAMIC_DISPATCH.name(), - argSize, - builder, - isExternal, - astParentType, - astParentFullName - ) - end createMethodStub - - /** Once we have connected methods that were speculatively generated and managed to correctly - * link to methods already in the CPG, we link the rest to the "speculativeMethods" namespace - * as a way to show that these may not actually exist. - */ - protected def linkSpeculativeNamespaceNodes( - newMethods: IterableOnce[NewMethod], - builder: DiffGraphBuilder - ): Unit = - val speculativeNamespace = - NewNamespaceBlock().name(XTypeHintCallLinker.namespace).fullName( - XTypeHintCallLinker.namespace - ) - - builder.addNode(speculativeNamespace) - newMethods.iterator - .filter(_.astParentFullName == XTypeHintCallLinker.namespace) - .foreach(m => builder.addEdge(speculativeNamespace, m, EdgeTypes.AST)) + val mNames = ms.map(_.fullName).toSet + methodNames + .filterNot(mNames.contains) + .map(fullName => + newMethods.getOrElseUpdate( + fullName, + createMethodStub(fullName, call, builder) + ) + ) + .foreach { m => methodMap.put(m.fullName, m) } + } + methodMap.toMap + end mapMethods + + protected def linkCallsToCallees( + callerAndCallees: List[(Call, Seq[String])], + methodMap: Map[String, MethodBase], + builder: DiffGraphBuilder + ): Unit = + // Link edges to method nodes + callerAndCallees.foreach { case (call, methodNames) => + methodNames + .flatMap(methodMap.get) + .filter(m => call.callee(NoResolve).fullNameExact(m.fullName).isEmpty) + .foreach { m => linkCallToCallee(call, m, builder) } + setCallees(call, methodNames, builder) + } + + def linkCallToCallee(call: Call, method: MethodBase, builder: DiffGraphBuilder): Unit = + builder.addEdge(call, method, EdgeTypes.CALL) + method match + case method: Method => + builder.setNodeProperty( + call, + PropertyNames.TYPE_FULL_NAME, + method.methodReturn.typeFullName + ) + case _ => + + protected def setCallees( + call: Call, + methodNames: Seq[String], + builder: DiffGraphBuilder + ): Unit = + val nonDummyTypes = methodNames.filterNot(isDummyType) + if methodNames.sizeIs == 1 then + builder.setNodeProperty(call, PropertyNames.METHOD_FULL_NAME, methodNames.head) + builder.setNodeProperty( + call, + PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, + call.dynamicTypeHintFullName.diff(methodNames) + ) + else if methodNames.sizeIs > 1 && methodNames != nonDummyTypes then + setCallees(call, nonDummyTypes, builder) + + protected def createMethodStub( + methodName: String, + call: Call, + builder: DiffGraphBuilder + ): NewMethod = + // In the case of Python/JS we can use name info to check if, despite the method name might be incorrect, that we + // label the method correctly as internal by finding that the method should belong to an internal file + val matcher = fileNamePattern.matcher(methodName) + val basePath = cpg.metaData.root.head + val isExternal = if matcher.matches() then + val fileName = matcher.group(1) + cpg.file.nameExact(s"$basePath$fileName").isEmpty + else + true + val name = + if methodName.contains(pathSep) && methodName.length > methodName.lastIndexOf( + pathSep + ) + 1 + then + methodName.substring(methodName.lastIndexOf(pathSep) + 1) + else methodName + createMethodStub(name, methodName, call.argumentOut.size, isExternal, builder) + end createMethodStub + + /** Try to extract a type full name from the method full name, if one exists in the CPG then we + * are lucky and we use it, else we ignore for now. + */ + protected def createMethodStub( + name: String, + fullName: String, + argSize: Int, + isExternal: Boolean, + builder: DiffGraphBuilder + ): NewMethod = + val nameIdx = fullName.lastIndexOf(name) + val default = (NodeTypes.NAMESPACE_BLOCK, XTypeHintCallLinker.namespace) + val (astParentType, astParentFullName) = + if !fullName.isBlank && !fullName.startsWith(" 0 then + cpg.typeDecl + .fullNameExact(fullName.substring(0, nameIdx - 1)) + .map(t => t.label -> t.fullName) + .headOption + .getOrElse(default) + else + default + + MethodStubCreator + .createMethodStub( + name, + fullName, + "", + DispatchTypes.DYNAMIC_DISPATCH.name(), + argSize, + builder, + isExternal, + astParentType, + astParentFullName + ) + end createMethodStub + + /** Once we have connected methods that were speculatively generated and managed to correctly link + * to methods already in the CPG, we link the rest to the "speculativeMethods" namespace as a way + * to show that these may not actually exist. + */ + protected def linkSpeculativeNamespaceNodes( + newMethods: IterableOnce[NewMethod], + builder: DiffGraphBuilder + ): Unit = + val speculativeNamespace = + NewNamespaceBlock().name(XTypeHintCallLinker.namespace).fullName( + XTypeHintCallLinker.namespace + ) + + builder.addNode(speculativeNamespace) + newMethods.iterator + .filter(_.astParentFullName == XTypeHintCallLinker.namespace) + .foreach(m => builder.addEdge(speculativeNamespace, m, EdgeTypes.AST)) end XTypeHintCallLinker object XTypeHintCallLinker: - /** The shared namespace for all methods generated from the type recovery that may not exist - * with this exact full name in reality. - */ - val namespace: String = "" + /** The shared namespace for all methods generated from the type recovery that may not exist with + * this exact full name in reality. + */ + val namespace: String = "" diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XTypeRecovery.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XTypeRecovery.scala index 18406b20..27619ed2 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XTypeRecovery.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/frontend/XTypeRecovery.scala @@ -44,11 +44,11 @@ case class XTypeRecoveryState( changesWereMade: AtomicBoolean = new AtomicBoolean(false), stopEarly: AtomicBoolean ): - lazy val isFinalIteration: Boolean = currentIteration == config.iterations - 1 + lazy val isFinalIteration: Boolean = currentIteration == config.iterations - 1 - lazy val isFirstIteration: Boolean = currentIteration == 0 + lazy val isFirstIteration: Boolean = currentIteration == 0 - def clear(): Unit = isFieldCache.clear() + def clear(): Unit = isFieldCache.clear() /** In order to propagate types across compilation units, but avoid the poor scalability of a * fixed-point algorithm, the number of iterations can be configured using the iterations @@ -65,40 +65,39 @@ abstract class XTypeRecoveryPass[CompilationUnitType <: AstNode]( config: XTypeRecoveryConfig = XTypeRecoveryConfig() ) extends CpgPass(cpg): - override def run(builder: BatchedUpdate.DiffGraphBuilder): Unit = - if config.iterations > 0 then - val stopEarly = new AtomicBoolean(false) - val state = XTypeRecoveryState(config, stopEarly = stopEarly) - try - Iterator.from(0).takeWhile(_ < config.iterations).foreach { i => - val newState = state.copy(currentIteration = i) - generateRecoveryPass(newState).createAndApply() - } - // If dummy values are enabled and we are stopping early, we need one more round to propagate these dummy values - if stopEarly.get() && config.enabledDummyTypes then - generateRecoveryPass( - state.copy(currentIteration = config.iterations - 1) - ).createAndApply() - finally - state.clear() - - protected def generateRecoveryPass(state: XTypeRecoveryState) - : XTypeRecovery[CompilationUnitType] + override def run(builder: BatchedUpdate.DiffGraphBuilder): Unit = + if config.iterations > 0 then + val stopEarly = new AtomicBoolean(false) + val state = XTypeRecoveryState(config, stopEarly = stopEarly) + try + Iterator.from(0).takeWhile(_ < config.iterations).foreach { i => + val newState = state.copy(currentIteration = i) + generateRecoveryPass(newState).createAndApply() + } + // If dummy values are enabled and we are stopping early, we need one more round to propagate these dummy values + if stopEarly.get() && config.enabledDummyTypes then + generateRecoveryPass( + state.copy(currentIteration = config.iterations - 1) + ).createAndApply() + finally + state.clear() + + protected def generateRecoveryPass(state: XTypeRecoveryState): XTypeRecovery[CompilationUnitType] end XTypeRecoveryPass trait TypeRecoveryParserConfig[R <: X2CpgConfig[R]]: - this: R => + this: R => - var disableDummyTypes: Boolean = false - var typePropagationIterations: Int = 2 + var disableDummyTypes: Boolean = false + var typePropagationIterations: Int = 2 - def withDisableDummyTypes(value: Boolean): R = - this.disableDummyTypes = value - this + def withDisableDummyTypes(value: Boolean): R = + this.disableDummyTypes = value + this - def withTypePropagationIterations(value: Int): R = - typePropagationIterations = value - this + def withTypePropagationIterations(value: Int): R = + typePropagationIterations = value + this /** Based on a flow-insensitive static single-assignment symbol-table-style approach. This pass aims * to be fast and deterministic and does not try to converge to some fixed point but rather @@ -128,79 +127,78 @@ trait TypeRecoveryParserConfig[R <: X2CpgConfig[R]]: abstract class XTypeRecovery[CompilationUnitType <: AstNode](cpg: Cpg, state: XTypeRecoveryState) extends CpgPass(cpg): - override def run(builder: DiffGraphBuilder): Unit = - val changesWereMade = compilationUnit - .map(unit => generateRecoveryForCompilationUnitTask(unit, builder).fork()) - .map(_.get) - .reduceOption((a, b) => a || b) - .getOrElse(false) - if !changesWereMade then state.stopEarly.set(true) - - /** @return - * the compilation units as per how the language is compiled. e.g. file. - */ - def compilationUnit: Iterator[CompilationUnitType] - - /** A factory method to generate a [[RecoverForXCompilationUnit]] task with the given - * parameters. - * - * @param unit - * the compilation unit. - * @param builder - * the graph builder. - * @return - * a forkable [[RecoverForXCompilationUnit]] task. - */ - def generateRecoveryForCompilationUnitTask( - unit: CompilationUnitType, - builder: DiffGraphBuilder - ): RecoverForXCompilationUnit[CompilationUnitType] + override def run(builder: DiffGraphBuilder): Unit = + val changesWereMade = compilationUnit + .map(unit => generateRecoveryForCompilationUnitTask(unit, builder).fork()) + .map(_.get) + .reduceOption((a, b) => a || b) + .getOrElse(false) + if !changesWereMade then state.stopEarly.set(true) + + /** @return + * the compilation units as per how the language is compiled. e.g. file. + */ + def compilationUnit: Iterator[CompilationUnitType] + + /** A factory method to generate a [[RecoverForXCompilationUnit]] task with the given parameters. + * + * @param unit + * the compilation unit. + * @param builder + * the graph builder. + * @return + * a forkable [[RecoverForXCompilationUnit]] task. + */ + def generateRecoveryForCompilationUnitTask( + unit: CompilationUnitType, + builder: DiffGraphBuilder + ): RecoverForXCompilationUnit[CompilationUnitType] end XTypeRecovery object XTypeRecovery: - private val logger = LoggerFactory.getLogger(getClass) - - val DummyReturnType = "" - val DummyMemberLoad = "" - val DummyIndexAccess = "" - private lazy val DummyTokens: Set[String] = - Set(DummyReturnType, DummyMemberLoad, DummyIndexAccess) - - def dummyMemberType(prefix: String, memberName: String, sep: Char = '.'): String = - s"$prefix$sep$DummyMemberLoad($memberName)" - - /** Scans the type for placeholder/dummy types. - */ - def isDummyType(typ: String): Boolean = DummyTokens.exists(typ.contains) - - /** Parser options for languages implementing this pass. - */ - def parserOptions[R <: X2CpgConfig[R] & TypeRecoveryParserConfig[R]]: OParser[?, R] = - val builder = OParser.builder[R] - import builder.* - OParser.sequence( - opt[Unit]("no-dummyTypes") - .hidden() - .action((_, c) => c.withDisableDummyTypes(true)) - .text("disable generation of dummy types during type propagation"), - opt[Int]("type-prop-iterations") - .hidden() - .action((x, c) => c.withTypePropagationIterations(x)) - .text("maximum iterations of type propagation") - .validate { x => - if x <= 0 then - logger.debug( - "Disabling type propagation as the given iteration count is <= 0" - ) - else if x == 1 then - logger.debug("Intra-procedural type propagation enabled") - else if x > 5 then - logger.debug(s"Large iteration count of $x will take a while to terminate") - success - } - ) - end parserOptions + private val logger = LoggerFactory.getLogger(getClass) + + val DummyReturnType = "" + val DummyMemberLoad = "" + val DummyIndexAccess = "" + private lazy val DummyTokens: Set[String] = + Set(DummyReturnType, DummyMemberLoad, DummyIndexAccess) + + def dummyMemberType(prefix: String, memberName: String, sep: Char = '.'): String = + s"$prefix$sep$DummyMemberLoad($memberName)" + + /** Scans the type for placeholder/dummy types. + */ + def isDummyType(typ: String): Boolean = DummyTokens.exists(typ.contains) + + /** Parser options for languages implementing this pass. + */ + def parserOptions[R <: X2CpgConfig[R] & TypeRecoveryParserConfig[R]]: OParser[?, R] = + val builder = OParser.builder[R] + import builder.* + OParser.sequence( + opt[Unit]("no-dummyTypes") + .hidden() + .action((_, c) => c.withDisableDummyTypes(true)) + .text("disable generation of dummy types during type propagation"), + opt[Int]("type-prop-iterations") + .hidden() + .action((x, c) => c.withTypePropagationIterations(x)) + .text("maximum iterations of type propagation") + .validate { x => + if x <= 0 then + logger.debug( + "Disabling type propagation as the given iteration count is <= 0" + ) + else if x == 1 then + logger.debug("Intra-procedural type propagation enabled") + else if x > 5 then + logger.debug(s"Large iteration count of $x will take a while to terminate") + success + } + ) + end parserOptions end XTypeRecovery /** Performs type recovery from the root of a compilation unit level @@ -221,1058 +219,1058 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode]( state: XTypeRecoveryState ) extends RecursiveTask[Boolean]: - protected val logger: Logger = LoggerFactory.getLogger(getClass) - - /** Stores type information for local structures that live within this compilation unit, e.g. - * local variables. - */ - protected val symbolTable = new SymbolTable[LocalKey](SBKey.fromNodeToLocalKey) - - /** The root of the target codebase. - */ - protected val codeRoot: String = - cpg.metaData.root.headOption.getOrElse("") + java.io.File.separator - - /** The delimiter used to separate methods/functions in qualified names. - */ - protected val pathSep = '.' - - /** New node tracking set. - */ - protected val addedNodes = mutable.HashSet.empty[String] - - /** For tracking members and the type operations that need to be performed. Since these are - * mostly out of scope locally it helps to track these separately. - * - * // TODO: Potentially a new use for a global table or modification to the symbol table? - */ - protected val newTypesForMembers = mutable.HashMap.empty[Member, Set[String]] - - /** Provides an entrypoint to add known symbols and their possible types. - */ - protected def prepopulateSymbolTable(): Unit = - (cu.ast.isIdentifier ++ cu.ast.isCall ++ cu.ast.isLocal ++ cu.ast.isParameter) - .filter(hasTypes) - .foreach(prepopulateSymbolTableEntry) - - protected def prepopulateSymbolTableEntry(x: AstNode): Unit = x match - case x @ (_: Identifier | _: Local | _: MethodParameterIn) => - symbolTable.append(x, x.getKnownTypes) - case x: Call => symbolTable.append(x, (x.methodFullName +: x.dynamicTypeHintFullName).toSet) - case _ => - - protected def hasTypes(node: AstNode): Boolean = node match - case x: Call if !x.methodFullName.startsWith("") => - !x.methodFullName.toLowerCase().matches("(|any)") - case x => x.getKnownTypes.nonEmpty - - protected def assignments: Iterator[Assignment] = cu match - case x: File => - x.method.flatMap(_._callViaContainsOut).nameExact(Operators.assignment).map( - new OpNodes.Assignment(_) - ) - case x: Method => x.flatMap(_._callViaContainsOut).nameExact(Operators.assignment).map( - new OpNodes.Assignment(_) - ) - case x => x.ast.isCall.nameExact(Operators.assignment).map(new OpNodes.Assignment(_)) - - protected def members: Iterator[Member] = cu.ast.isMember - - protected def returns: Iterator[Return] = cu match - case x: File => x.method.flatMap(_._returnViaContainsOut) - case x: Method => x._returnViaContainsOut - case x => x.ast.isReturn - - protected def importNodes: Iterator[Import] = cu match - case x: File => x.method.flatMap(_._callViaContainsOut).referencedImports - case x: Method => x.file.method.flatMap(_._callViaContainsOut).referencedImports - case x => x.ast.isCall.referencedImports - - override def compute(): Boolean = - try - // Set known aliases that point to imports for local and external methods/modules - importNodes.foreach(visitImport) - // Look at symbols with existing type info - prepopulateSymbolTable() - // Prune import names if the methods exist in the CPG - postVisitImports() - // Populate local symbol table with assignments - assignments.foreach(visitAssignments) - // See if any new information are in the parameters of methods - returns.foreach(visitReturns) - // Persist findings - setTypeInformation() - // Entrypoint for any final changes - postSetTypeInformation() - // Return number of changes - state.changesWereMade.get() - finally - symbolTable.clear() - - private def debugLocation(n: AstNode): String = - val fileName = n.file.name.headOption.getOrElse("").stripPrefix(codeRoot) - val lineNo = n.lineNumber.getOrElse("") - s"$fileName#L$lineNo" - - /** Visits an import and stores references in the symbol table as both an identifier and call. - */ - protected def visitImport(i: Import): Unit = for - resolvedImport <- i.call.tag - alias <- i.importedAs - do - import io.appthreat.x2cpg.passes.frontend.ImportsPass.* - - ResolvedImport.tagToResolvedImport(resolvedImport).foreach { - case ResolvedMethod(fullName, alias, receiver, _) => - symbolTable.append(CallAlias(alias, receiver), fullName) - case ResolvedTypeDecl(fullName, _) => - symbolTable.append(LocalVar(alias), fullName) - case ResolvedMember(basePath, memberName, _) => - val matchingIdentifiers = cpg.method.fullNameExact(basePath).local - val matchingMembers = cpg.typeDecl.fullNameExact(basePath).member - val memberTypes = (matchingMembers ++ matchingIdentifiers) - .nameExact(memberName) - .getKnownTypes - symbolTable.append(LocalVar(alias), memberTypes) - case UnknownMethod(fullName, alias, receiver, _) => - symbolTable.append(CallAlias(alias, receiver), fullName) - case UnknownTypeDecl(fullName, _) => - symbolTable.append(LocalVar(alias), fullName) - case UnknownImport(path, _) => - symbolTable.append(CallAlias(alias), path) - symbolTable.append(LocalVar(alias), path) - } - - /** The initial import setting is over-approximated, so this step checks the CPG for any matches - * and prunes against these findings. If there are no findings, it will leave the table as is. - * The latter is significant for external types or methods. - */ - protected def postVisitImports(): Unit = {} - - /** Using assignment and import information (in the global symbol table), will propagate these - * types in the symbol table. - * - * @param a - * assignment call pointer. - */ - protected def visitAssignments(a: Assignment): Set[String] = - a.argumentOut.l match - case List(i: Identifier, b: Block) => visitIdentifierAssignedToBlock(i, b) - case List(i: Identifier, c: Call) => visitIdentifierAssignedToCall(i, c) - case List(x: Identifier, y: Identifier) => visitIdentifierAssignedToIdentifier(x, y) - case List(i: Identifier, l: Literal) if state.isFirstIteration => - visitIdentifierAssignedToLiteral(i, l) - case List(i: Identifier, m: MethodRef) => visitIdentifierAssignedToMethodRef(i, m) - case List(i: Identifier, t: TypeRef) => visitIdentifierAssignedToTypeRef(i, t) - case List(c: Call, i: Identifier) => visitCallAssignedToIdentifier(c, i) - case List(x: Call, y: Call) => visitCallAssignedToCall(x, y) - case List(c: Call, l: Literal) if state.isFirstIteration => - visitCallAssignedToLiteral(c, l) - case List(c: Call, m: MethodRef) => visitCallAssignedToMethodRef(c, m) - case List(c: Call, b: Block) => visitCallAssignedToBlock(c, b) - case _ => Set.empty - - /** Visits an identifier being assigned to the result of some operation. - */ - protected def visitIdentifierAssignedToBlock(i: Identifier, b: Block): Set[String] = - val blockTypes = visitStatementsInBlock(b, Some(i)) - if blockTypes.nonEmpty then associateTypes(i, blockTypes) - else Set.empty - - /** Visits a call operation being assigned to the result of some operation. - */ - protected def visitCallAssignedToBlock(c: Call, b: Block): Set[String] = - val blockTypes = visitStatementsInBlock(b) - assignTypesToCall(c, blockTypes) - - /** Process each statement but only assign the type of the last statement to the identifier - */ - protected def visitStatementsInBlock( - b: Block, - assignmentTarget: Option[Identifier] = None - ): Set[String] = - b.astChildren - .map { - case x: Call if x.name.startsWith(Operators.assignment) => - visitAssignments(new Assignment(x)) - case x: Call if x.name.startsWith("") && assignmentTarget.isDefined => - visitIdentifierAssignedToOperator(assignmentTarget.get, x, x.name) - case x: Identifier if symbolTable.contains(x) => symbolTable.get(x) - case x: Call if symbolTable.contains(x) => symbolTable.get(x) - case x: Call if x.argument.headOption.exists(symbolTable.contains) => - setCallMethodFullNameFromBase(x) - case x: Block => visitStatementsInBlock(x) - case x: Local => symbolTable.get(x) - case _: ControlStructure => Set.empty[String] - case x => - logger.debug( - s"Unhandled block element ${x.label}:${x.code} @ ${debugLocation(x)}" - ); Set.empty[String] - } - .lastOption - .getOrElse(Set.empty[String]) - - /** Visits an identifier being assigned to a call. This call could be an operation, function - * invocation, or constructor invocation. - */ - protected def visitIdentifierAssignedToCall(i: Identifier, c: Call): Set[String] = - if c.name.startsWith("") then - visitIdentifierAssignedToOperator(i, c, c.name) - else if symbolTable.contains(c) && isConstructor(c) then - visitIdentifierAssignedToConstructor(i, c) - else if symbolTable.contains(c) then - visitIdentifierAssignedToCallRetVal(i, c) - else if c.argument.argumentIndex(0).headOption.exists(symbolTable.contains) then - setCallMethodFullNameFromBase(c) - // Repeat this method now that the call has a type - visitIdentifierAssignedToCall(i, c) - else - // We can try obtain a return type for this call - visitIdentifierAssignedToCallRetVal(i, c) - - /** Visits an identifier being assigned to the value held by another identifier. This is a weak - * copy. - */ - protected def visitIdentifierAssignedToIdentifier(x: Identifier, y: Identifier): Set[String] = - if symbolTable.contains(y) then associateTypes(x, symbolTable.get(y)) - else Set.empty - - /** Will build a call full path using the call base node. This method assumes the base node is - * in the symbol table. - */ - protected def setCallMethodFullNameFromBase(c: Call): Set[String] = - val recTypes = c.argument.headOption - .map { - case x: Call if x.typeFullName != "ANY" => Set(x.typeFullName) - case x: Call => - cpg.method.fullNameExact(c.methodFullName).methodReturn.typeFullNameNot( - "ANY" - ).typeFullName.toSet match - case xs if xs.nonEmpty => xs - case _ => symbolTable.get(x).map(t => - Seq(t, XTypeRecovery.DummyReturnType).mkString(pathSep.toString) - ) - case x => symbolTable.get(x) - } - .getOrElse(Set.empty[String]) - val callTypes = recTypes.map(_.concat(s"$pathSep${c.name}")) - symbolTable.append(c, callTypes) - - /** A heuristic method to determine if a call is a constructor or not. - */ - protected def isConstructor(c: Call): Boolean - - /** A heuristic method to determine if a call name is a constructor or not. - */ - protected def isConstructor(name: String): Boolean - - /** A heuristic method to determine if an identifier may be a field or not. The result means - * that it would be stored in the global symbol table. By default this checks if the identifier - * name matches a member name. - * - * This has found to be an expensive operation accessed often so we have memoized this step. - */ - protected def isField(i: Identifier): Boolean = - state.isFieldCache.getOrElseUpdate( - i.id(), - i.method.typeDecl.member.nameExact(i.name).nonEmpty + protected val logger: Logger = LoggerFactory.getLogger(getClass) + + /** Stores type information for local structures that live within this compilation unit, e.g. + * local variables. + */ + protected val symbolTable = new SymbolTable[LocalKey](SBKey.fromNodeToLocalKey) + + /** The root of the target codebase. + */ + protected val codeRoot: String = + cpg.metaData.root.headOption.getOrElse("") + java.io.File.separator + + /** The delimiter used to separate methods/functions in qualified names. + */ + protected val pathSep = '.' + + /** New node tracking set. + */ + protected val addedNodes = mutable.HashSet.empty[String] + + /** For tracking members and the type operations that need to be performed. Since these are mostly + * out of scope locally it helps to track these separately. + * + * // TODO: Potentially a new use for a global table or modification to the symbol table? + */ + protected val newTypesForMembers = mutable.HashMap.empty[Member, Set[String]] + + /** Provides an entrypoint to add known symbols and their possible types. + */ + protected def prepopulateSymbolTable(): Unit = + (cu.ast.isIdentifier ++ cu.ast.isCall ++ cu.ast.isLocal ++ cu.ast.isParameter) + .filter(hasTypes) + .foreach(prepopulateSymbolTableEntry) + + protected def prepopulateSymbolTableEntry(x: AstNode): Unit = x match + case x @ (_: Identifier | _: Local | _: MethodParameterIn) => + symbolTable.append(x, x.getKnownTypes) + case x: Call => symbolTable.append(x, (x.methodFullName +: x.dynamicTypeHintFullName).toSet) + case _ => + + protected def hasTypes(node: AstNode): Boolean = node match + case x: Call if !x.methodFullName.startsWith("") => + !x.methodFullName.toLowerCase().matches("(|any)") + case x => x.getKnownTypes.nonEmpty + + protected def assignments: Iterator[Assignment] = cu match + case x: File => + x.method.flatMap(_._callViaContainsOut).nameExact(Operators.assignment).map( + new OpNodes.Assignment(_) ) - - /** Associates the types with the identifier. This may sometimes be an identifier that should be - * considered a field which this method uses [[isField]] to determine. - */ - protected def associateTypes(i: Identifier, types: Set[String]): Set[String] = - symbolTable.append(i, types) - - /** Returns the appropriate field parent scope. - */ - protected def getFieldParents(fa: FieldAccess): Set[String] = - val fieldName = getFieldName(fa).split(pathSep).last - cpg.member.nameExact(fieldName).typeDecl.fullName.filterNot(_.contains("ANY")).toSet - - /** Associates the types with the identifier. This may sometimes be an identifier that should be - * considered a field which this method uses [[isField]] to determine. - */ - protected def associateTypes( - symbol: LocalVar, - fa: FieldAccess, - types: Set[String] - ): Set[String] = - fa.astChildren.filterNot(_.code.matches("(this|self)")).headOption.collect { - case fi: FieldIdentifier => - getFieldParents(fa).foreach(t => - persistMemberWithTypeDecl(t, fi.canonicalName, types) - ) - case i: Identifier if isField(i) => - getFieldParents(fa).foreach(t => persistMemberWithTypeDecl(t, i.name, types)) + case x: Method => x.flatMap(_._callViaContainsOut).nameExact(Operators.assignment).map( + new OpNodes.Assignment(_) + ) + case x => x.ast.isCall.nameExact(Operators.assignment).map(new OpNodes.Assignment(_)) + + protected def members: Iterator[Member] = cu.ast.isMember + + protected def returns: Iterator[Return] = cu match + case x: File => x.method.flatMap(_._returnViaContainsOut) + case x: Method => x._returnViaContainsOut + case x => x.ast.isReturn + + protected def importNodes: Iterator[Import] = cu match + case x: File => x.method.flatMap(_._callViaContainsOut).referencedImports + case x: Method => x.file.method.flatMap(_._callViaContainsOut).referencedImports + case x => x.ast.isCall.referencedImports + + override def compute(): Boolean = + try + // Set known aliases that point to imports for local and external methods/modules + importNodes.foreach(visitImport) + // Look at symbols with existing type info + prepopulateSymbolTable() + // Prune import names if the methods exist in the CPG + postVisitImports() + // Populate local symbol table with assignments + assignments.foreach(visitAssignments) + // See if any new information are in the parameters of methods + returns.foreach(visitReturns) + // Persist findings + setTypeInformation() + // Entrypoint for any final changes + postSetTypeInformation() + // Return number of changes + state.changesWereMade.get() + finally + symbolTable.clear() + + private def debugLocation(n: AstNode): String = + val fileName = n.file.name.headOption.getOrElse("").stripPrefix(codeRoot) + val lineNo = n.lineNumber.getOrElse("") + s"$fileName#L$lineNo" + + /** Visits an import and stores references in the symbol table as both an identifier and call. + */ + protected def visitImport(i: Import): Unit = for + resolvedImport <- i.call.tag + alias <- i.importedAs + do + import io.appthreat.x2cpg.passes.frontend.ImportsPass.* + + ResolvedImport.tagToResolvedImport(resolvedImport).foreach { + case ResolvedMethod(fullName, alias, receiver, _) => + symbolTable.append(CallAlias(alias, receiver), fullName) + case ResolvedTypeDecl(fullName, _) => + symbolTable.append(LocalVar(alias), fullName) + case ResolvedMember(basePath, memberName, _) => + val matchingIdentifiers = cpg.method.fullNameExact(basePath).local + val matchingMembers = cpg.typeDecl.fullNameExact(basePath).member + val memberTypes = (matchingMembers ++ matchingIdentifiers) + .nameExact(memberName) + .getKnownTypes + symbolTable.append(LocalVar(alias), memberTypes) + case UnknownMethod(fullName, alias, receiver, _) => + symbolTable.append(CallAlias(alias, receiver), fullName) + case UnknownTypeDecl(fullName, _) => + symbolTable.append(LocalVar(alias), fullName) + case UnknownImport(path, _) => + symbolTable.append(CallAlias(alias), path) + symbolTable.append(LocalVar(alias), path) + } + + /** The initial import setting is over-approximated, so this step checks the CPG for any matches + * and prunes against these findings. If there are no findings, it will leave the table as is. + * The latter is significant for external types or methods. + */ + protected def postVisitImports(): Unit = {} + + /** Using assignment and import information (in the global symbol table), will propagate these + * types in the symbol table. + * + * @param a + * assignment call pointer. + */ + protected def visitAssignments(a: Assignment): Set[String] = + a.argumentOut.l match + case List(i: Identifier, b: Block) => visitIdentifierAssignedToBlock(i, b) + case List(i: Identifier, c: Call) => visitIdentifierAssignedToCall(i, c) + case List(x: Identifier, y: Identifier) => visitIdentifierAssignedToIdentifier(x, y) + case List(i: Identifier, l: Literal) if state.isFirstIteration => + visitIdentifierAssignedToLiteral(i, l) + case List(i: Identifier, m: MethodRef) => visitIdentifierAssignedToMethodRef(i, m) + case List(i: Identifier, t: TypeRef) => visitIdentifierAssignedToTypeRef(i, t) + case List(c: Call, i: Identifier) => visitCallAssignedToIdentifier(c, i) + case List(x: Call, y: Call) => visitCallAssignedToCall(x, y) + case List(c: Call, l: Literal) if state.isFirstIteration => + visitCallAssignedToLiteral(c, l) + case List(c: Call, m: MethodRef) => visitCallAssignedToMethodRef(c, m) + case List(c: Call, b: Block) => visitCallAssignedToBlock(c, b) + case _ => Set.empty + + /** Visits an identifier being assigned to the result of some operation. + */ + protected def visitIdentifierAssignedToBlock(i: Identifier, b: Block): Set[String] = + val blockTypes = visitStatementsInBlock(b, Some(i)) + if blockTypes.nonEmpty then associateTypes(i, blockTypes) + else Set.empty + + /** Visits a call operation being assigned to the result of some operation. + */ + protected def visitCallAssignedToBlock(c: Call, b: Block): Set[String] = + val blockTypes = visitStatementsInBlock(b) + assignTypesToCall(c, blockTypes) + + /** Process each statement but only assign the type of the last statement to the identifier + */ + protected def visitStatementsInBlock( + b: Block, + assignmentTarget: Option[Identifier] = None + ): Set[String] = + b.astChildren + .map { + case x: Call if x.name.startsWith(Operators.assignment) => + visitAssignments(new Assignment(x)) + case x: Call if x.name.startsWith("") && assignmentTarget.isDefined => + visitIdentifierAssignedToOperator(assignmentTarget.get, x, x.name) + case x: Identifier if symbolTable.contains(x) => symbolTable.get(x) + case x: Call if symbolTable.contains(x) => symbolTable.get(x) + case x: Call if x.argument.headOption.exists(symbolTable.contains) => + setCallMethodFullNameFromBase(x) + case x: Block => visitStatementsInBlock(x) + case x: Local => symbolTable.get(x) + case _: ControlStructure => Set.empty[String] + case x => + logger.debug( + s"Unhandled block element ${x.label}:${x.code} @ ${debugLocation(x)}" + ); Set.empty[String] + } + .lastOption + .getOrElse(Set.empty[String]) + + /** Visits an identifier being assigned to a call. This call could be an operation, function + * invocation, or constructor invocation. + */ + protected def visitIdentifierAssignedToCall(i: Identifier, c: Call): Set[String] = + if c.name.startsWith("") then + visitIdentifierAssignedToOperator(i, c, c.name) + else if symbolTable.contains(c) && isConstructor(c) then + visitIdentifierAssignedToConstructor(i, c) + else if symbolTable.contains(c) then + visitIdentifierAssignedToCallRetVal(i, c) + else if c.argument.argumentIndex(0).headOption.exists(symbolTable.contains) then + setCallMethodFullNameFromBase(c) + // Repeat this method now that the call has a type + visitIdentifierAssignedToCall(i, c) + else + // We can try obtain a return type for this call + visitIdentifierAssignedToCallRetVal(i, c) + + /** Visits an identifier being assigned to the value held by another identifier. This is a weak + * copy. + */ + protected def visitIdentifierAssignedToIdentifier(x: Identifier, y: Identifier): Set[String] = + if symbolTable.contains(y) then associateTypes(x, symbolTable.get(y)) + else Set.empty + + /** Will build a call full path using the call base node. This method assumes the base node is in + * the symbol table. + */ + protected def setCallMethodFullNameFromBase(c: Call): Set[String] = + val recTypes = c.argument.headOption + .map { + case x: Call if x.typeFullName != "ANY" => Set(x.typeFullName) + case x: Call => + cpg.method.fullNameExact(c.methodFullName).methodReturn.typeFullNameNot( + "ANY" + ).typeFullName.toSet match + case xs if xs.nonEmpty => xs + case _ => symbolTable.get(x).map(t => + Seq(t, XTypeRecovery.DummyReturnType).mkString(pathSep.toString) + ) + case x => symbolTable.get(x) } - symbolTable.append(symbol, types) - - /** Similar to associateTypes but used in the case where there is some kind of field load. - */ - protected def associateInterproceduralTypes( - i: Identifier, - base: Identifier, - fi: FieldIdentifier, - fieldName: String, - baseTypes: Set[String] - ): Set[String] = - val globalTypes = getFieldBaseType(base, fi) - associateInterproceduralTypes(i, fieldName, fi.canonicalName, globalTypes, baseTypes) - - protected def associateInterproceduralTypes( - i: Identifier, - fieldFullName: String, - fieldName: String, - globalTypes: Set[String], - baseTypes: Set[String] - ): Set[String] = - if globalTypes.nonEmpty then - // We have been able to resolve the type inter-procedurally - associateTypes(i, globalTypes) - else if baseTypes.nonEmpty then - if baseTypes.equals(symbolTable.get(LocalVar(fieldFullName))) then - associateTypes(i, baseTypes) - else - // If not available, use a dummy variable that can be useful for call matching - associateTypes( - i, - baseTypes.map(t => XTypeRecovery.dummyMemberType(t, fieldName, pathSep)) - ) - else - // Assign dummy - val dummyTypes = Set( - XTypeRecovery.dummyMemberType( - fieldFullName.stripSuffix(s"$pathSep$fieldName"), - fieldName, - pathSep - ) + .getOrElse(Set.empty[String]) + val callTypes = recTypes.map(_.concat(s"$pathSep${c.name}")) + symbolTable.append(c, callTypes) + + /** A heuristic method to determine if a call is a constructor or not. + */ + protected def isConstructor(c: Call): Boolean + + /** A heuristic method to determine if a call name is a constructor or not. + */ + protected def isConstructor(name: String): Boolean + + /** A heuristic method to determine if an identifier may be a field or not. The result means that + * it would be stored in the global symbol table. By default this checks if the identifier name + * matches a member name. + * + * This has found to be an expensive operation accessed often so we have memoized this step. + */ + protected def isField(i: Identifier): Boolean = + state.isFieldCache.getOrElseUpdate( + i.id(), + i.method.typeDecl.member.nameExact(i.name).nonEmpty + ) + + /** Associates the types with the identifier. This may sometimes be an identifier that should be + * considered a field which this method uses [[isField]] to determine. + */ + protected def associateTypes(i: Identifier, types: Set[String]): Set[String] = + symbolTable.append(i, types) + + /** Returns the appropriate field parent scope. + */ + protected def getFieldParents(fa: FieldAccess): Set[String] = + val fieldName = getFieldName(fa).split(pathSep).last + cpg.member.nameExact(fieldName).typeDecl.fullName.filterNot(_.contains("ANY")).toSet + + /** Associates the types with the identifier. This may sometimes be an identifier that should be + * considered a field which this method uses [[isField]] to determine. + */ + protected def associateTypes( + symbol: LocalVar, + fa: FieldAccess, + types: Set[String] + ): Set[String] = + fa.astChildren.filterNot(_.code.matches("(this|self)")).headOption.collect { + case fi: FieldIdentifier => + getFieldParents(fa).foreach(t => + persistMemberWithTypeDecl(t, fi.canonicalName, types) ) - associateTypes(i, dummyTypes) - - /** Visits an identifier being assigned to an operator call. - */ - protected def visitIdentifierAssignedToOperator( - i: Identifier, - c: Call, - operation: String - ): Set[String] = - operation match - case Operators.alloc => visitIdentifierAssignedToConstructor(i, c) - case Operators.fieldAccess => visitIdentifierAssignedToFieldLoad(i, new FieldAccess(c)) - case Operators.indexAccess => visitIdentifierAssignedToIndexAccess(i, c) - case Operators.cast => visitIdentifierAssignedToCast(i, c) - case x => - logger.debug(s"Unhandled operation $x (${c.code}) @ ${debugLocation(c)}"); Set.empty - - /** Visits an identifier being assigned to a constructor and attempts to speculate the - * constructor path. - */ - protected def visitIdentifierAssignedToConstructor(i: Identifier, c: Call): Set[String] = - val constructorPaths = - symbolTable.get(c).map(t => t.concat(s"$pathSep${Defines.ConstructorMethodName}")) - associateTypes(i, constructorPaths) - - /** Visits an identifier being assigned to a call's return value. - */ - protected def visitIdentifierAssignedToCallRetVal(i: Identifier, c: Call): Set[String] = - if symbolTable.contains(c) then - val callReturns = methodReturnValues(symbolTable.get(c).toSeq) - associateTypes(i, callReturns) - else if c.argument.exists(_.argumentIndex == 0) then - val callFullNames = (c.argument(0) match - case i: Identifier if symbolTable.contains(LocalVar(i.name)) => - symbolTable.get(LocalVar(i.name)) - case i: Identifier if symbolTable.contains(CallAlias(i.name)) => - symbolTable.get(CallAlias(i.name)) - case _ => Set.empty - ).map(_.concat(s"$pathSep${c.name}")).toSeq - val callReturns = methodReturnValues(callFullNames) - associateTypes(i, callReturns) + case i: Identifier if isField(i) => + getFieldParents(fa).foreach(t => persistMemberWithTypeDecl(t, i.name, types)) + } + symbolTable.append(symbol, types) + + /** Similar to associateTypes but used in the case where there is some kind of field load. + */ + protected def associateInterproceduralTypes( + i: Identifier, + base: Identifier, + fi: FieldIdentifier, + fieldName: String, + baseTypes: Set[String] + ): Set[String] = + val globalTypes = getFieldBaseType(base, fi) + associateInterproceduralTypes(i, fieldName, fi.canonicalName, globalTypes, baseTypes) + + protected def associateInterproceduralTypes( + i: Identifier, + fieldFullName: String, + fieldName: String, + globalTypes: Set[String], + baseTypes: Set[String] + ): Set[String] = + if globalTypes.nonEmpty then + // We have been able to resolve the type inter-procedurally + associateTypes(i, globalTypes) + else if baseTypes.nonEmpty then + if baseTypes.equals(symbolTable.get(LocalVar(fieldFullName))) then + associateTypes(i, baseTypes) else - // Assign dummy value - associateTypes(i, Set(s"${c.name}$pathSep${XTypeRecovery.DummyReturnType}")) - - /** Will attempt to find the return values of a method if in the CPG, otherwise will give a - * dummy value. - */ - protected def methodReturnValues(methodFullNames: Seq[String]): Set[String] = - val rs = cpg.method - .fullNameExact(methodFullNames*) - .methodReturn - .flatMap(mr => mr.typeFullName +: mr.dynamicTypeHintFullName) - .filterNot(_.equals("ANY")) - .toSet - if rs.isEmpty then - methodFullNames.map(_.concat(s"$pathSep${XTypeRecovery.DummyReturnType}")).toSet - else rs - - /** Will handle literal value assignments. Override if special handling is required. - */ - protected def visitIdentifierAssignedToLiteral(i: Identifier, l: Literal): Set[String] = - associateTypes(i, getLiteralType(l)) - - /** Not all frontends populate typeFullName for literals so we allow this to be - * overridden. - */ - protected def getLiteralType(l: Literal): Set[String] = Set(l.typeFullName) - - /** Will handle an identifier holding a function pointer. - */ - protected def visitIdentifierAssignedToMethodRef( - i: Identifier, - m: MethodRef, - rec: Option[String] = None - ): Set[String] = - symbolTable.append(CallAlias(i.name, rec), Set(m.methodFullName)) - - /** Will handle an identifier holding a type pointer. - */ - protected def visitIdentifierAssignedToTypeRef( - i: Identifier, - t: TypeRef, - rec: Option[String] = None - ): Set[String] = - symbolTable.append(CallAlias(i.name, rec), Set(t.typeFullName)) - - /** Visits a call assigned to an identifier. This is often when there are operators involved. - */ - protected def visitCallAssignedToIdentifier(c: Call, i: Identifier): Set[String] = - val rhsTypes = symbolTable.get(i) - assignTypesToCall(c, rhsTypes) - - /** Visits a call assigned to the return value of a call. This is often when there are operators - * involved. - */ - protected def visitCallAssignedToCall(x: Call, y: Call): Set[String] = - assignTypesToCall(x, getTypesFromCall(y)) - - /** Given a call operation, will attempt to retrieve types from it. - */ - protected def getTypesFromCall(c: Call): Set[String] = c.name match - case Operators.fieldAccess => symbolTable.get(LocalVar(getFieldName(new FieldAccess(c)))) - case _ if symbolTable.contains(c) => symbolTable.get(c) - case Operators.indexAccess => getIndexAccessTypes(c) - case n => - logger.debug(s"Unknown RHS call type '$n' @ ${debugLocation(c)}") - Set.empty[String] - - /** Given a LHS call, will retrieve its symbol to the given types. - */ - protected def assignTypesToCall(x: Call, types: Set[String]): Set[String] = - if types.nonEmpty then - getSymbolFromCall(x) match - case (lhs, globalKeys) if globalKeys.nonEmpty => - globalKeys.foreach { (fieldVar: FieldPath) => - persistMemberWithTypeDecl( - fieldVar.compUnitFullName, - fieldVar.identifier, - types - ) - } - symbolTable.append(lhs, types) - case (lhs, _) => symbolTable.append(lhs, types) - else Set.empty - - /** Will attempt to retrieve index access types otherwise will return dummy value. - */ - protected def getIndexAccessTypes(ia: Call): Set[String] = indexAccessToCollectionVar(ia) match - case Some(cVar) if symbolTable.contains(cVar) => - symbolTable.get(cVar) - case Some(cVar) if symbolTable.contains(LocalVar(cVar.identifier)) => - symbolTable.get(LocalVar(cVar.identifier)).map(x => - s"$x$pathSep${XTypeRecovery.DummyIndexAccess}" - ) - case _ => Set.empty - - /** Convenience class for transporting field names. - * @param compUnitFullName - * qualified path to base type holding the member. - * @param identifier - * the member name. - */ - case class FieldPath(compUnitFullName: String, identifier: String) - - /** Tries to identify the underlying symbol from the call operation as it is used on the LHS of - * an assignment. The second element is a list of any associated global keys if applicable. - */ - protected def getSymbolFromCall(c: Call): (LocalKey, Set[FieldPath]) = c.name match - case Operators.fieldAccess => - val fa = new FieldAccess(c) - val fieldName = getFieldName(fa) - (LocalVar(fieldName), getFieldParents(fa).map(fp => FieldPath(fp, fieldName))) - case Operators.indexAccess => - (indexAccessToCollectionVar(c).getOrElse(LocalVar(c.name)), Set.empty) + // If not available, use a dummy variable that can be useful for call matching + associateTypes( + i, + baseTypes.map(t => XTypeRecovery.dummyMemberType(t, fieldName, pathSep)) + ) + else + // Assign dummy + val dummyTypes = Set( + XTypeRecovery.dummyMemberType( + fieldFullName.stripSuffix(s"$pathSep$fieldName"), + fieldName, + pathSep + ) + ) + associateTypes(i, dummyTypes) + + /** Visits an identifier being assigned to an operator call. + */ + protected def visitIdentifierAssignedToOperator( + i: Identifier, + c: Call, + operation: String + ): Set[String] = + operation match + case Operators.alloc => visitIdentifierAssignedToConstructor(i, c) + case Operators.fieldAccess => visitIdentifierAssignedToFieldLoad(i, new FieldAccess(c)) + case Operators.indexAccess => visitIdentifierAssignedToIndexAccess(i, c) + case Operators.cast => visitIdentifierAssignedToCast(i, c) case x => - logger.debug(s"Using default LHS call name '$x' @ ${debugLocation(c)}") - (LocalVar(c.name), Set.empty) - - /** Extracts a string representation of the name of the field within this field access. - */ - protected def getFieldName(fa: FieldAccess, prefix: String = "", suffix: String = ""): String = - def wrapName(n: String) = - val sb = new mutable.StringBuilder() - if prefix.nonEmpty then sb.append(s"$prefix$pathSep") - sb.append(n) - if suffix.nonEmpty then sb.append(s"$pathSep$suffix") - sb.toString() - - lazy val typesFromBaseCall = fa.argumentOut.headOption match - case Some(call: Call) => getTypesFromCall(call) - case _ => Set.empty[String] - - fa.argumentOut.l match - case ::(i: Identifier, ::(f: FieldIdentifier, _)) if i.name.matches("(self|this)") => - wrapName(f.canonicalName) - case ::(i: Identifier, ::(f: FieldIdentifier, _)) => - wrapName(s"${i.name}$pathSep${f.canonicalName}") - case ::(c: Call, ::(f: FieldIdentifier, _)) if c.name.equals(Operators.fieldAccess) => - wrapName(getFieldName(new FieldAccess(c), suffix = f.canonicalName)) - case ::(_: Call, ::(f: FieldIdentifier, _)) if typesFromBaseCall.nonEmpty => - // TODO: Handle this case better - wrapName(s"${typesFromBaseCall.head}$pathSep${f.canonicalName}") - case ::(f: FieldIdentifier, ::(c: Call, _)) if c.name.equals(Operators.fieldAccess) => - wrapName(getFieldName(new FieldAccess(c), prefix = f.canonicalName)) - case ::(c: Call, ::(f: FieldIdentifier, _)) => - // TODO: Handle this case better - val callCode = - if c.code.contains("(") then c.code.substring(c.code.indexOf("(")) else c.code - XTypeRecovery.dummyMemberType(callCode, f.canonicalName, pathSep) - case xs => - logger.debug( - s"Unhandled field structure ${xs.map(x => (x.label, x.code)).mkString(",")} @ ${debugLocation(fa)}" - ) - wrapName("") - end match - end getFieldName - - protected def visitCallAssignedToLiteral(c: Call, l: Literal): Set[String] = - if c.name.equals(Operators.indexAccess) then - // For now, we will just handle this on a very basic level - c.argumentOut.l match - case List(_: Identifier, _: Literal) => - indexAccessToCollectionVar(c).map(cv => - symbolTable.append(cv, getLiteralType(l)) - ).getOrElse(Set.empty) - case List(_: Identifier, idx: Identifier) if symbolTable.contains(idx) => - // Imprecise but sound! - indexAccessToCollectionVar(c).map(cv => - symbolTable.append(cv, symbolTable.get(idx)) - ).getOrElse(Set.empty) - case List(i: Identifier, c: Call) => - // This is an expensive level of precision to support - symbolTable.append(CollectionVar(i.name, "*"), getTypesFromCall(c)) - case List(c: Call, l: Literal) => assignTypesToCall(c, getLiteralType(l)) - case xs => - logger.debug( - s"Unhandled index access point assigned to literal ${xs.map(x => - (x.label, x.code) - ).mkString(",")} @ ${debugLocation(c)}" - ) - Set.empty - else if c.name.equals(Operators.fieldAccess) then - val fa = new FieldAccess(c) - val fieldName = getFieldName(fa) - associateTypes(LocalVar(fieldName), fa, getLiteralType(l)) - else - logger.debug( - s"Unhandled call assigned to literal point ${c.name} @ ${debugLocation(c)}" - ) - Set.empty - - /** Handles a call operation assigned to a method/function pointer. - */ - protected def visitCallAssignedToMethodRef(c: Call, m: MethodRef): Set[String] = - assignTypesToCall(c, Set(m.methodFullName)) - - /** Generates an identifier for collection/index-access operations in the symbol table. - */ - protected def indexAccessToCollectionVar(c: Call): Option[CollectionVar] = - def callName(x: Call) = - if x.name.equals(Operators.fieldAccess) then - getFieldName(new FieldAccess(x)) - else if x.name.equals(Operators.indexAccess) then - indexAccessToCollectionVar(x) - .map(cv => s"${cv.identifier}[${cv.idx}]") - .getOrElse(XTypeRecovery.DummyIndexAccess) - else x.name - - Option(c.argumentOut.l match - case List(i: Identifier, idx: Literal) => CollectionVar(i.name, idx.code) - case List(i: Identifier, idx: Identifier) => CollectionVar(i.name, idx.code) - case List(c: Call, idx: Call) => CollectionVar(callName(c), callName(idx)) - case List(c: Call, idx: Literal) => CollectionVar(callName(c), idx.code) - case List(c: Call, idx: Identifier) => CollectionVar(callName(c), idx.code) - case xs => - logger.debug( - s"Unhandled index access ${xs.map(x => (x.label, x.code)).mkString(",")} @ ${debugLocation(c)}" - ) - null + logger.debug(s"Unhandled operation $x (${c.code}) @ ${debugLocation(c)}"); Set.empty + + /** Visits an identifier being assigned to a constructor and attempts to speculate the constructor + * path. + */ + protected def visitIdentifierAssignedToConstructor(i: Identifier, c: Call): Set[String] = + val constructorPaths = + symbolTable.get(c).map(t => t.concat(s"$pathSep${Defines.ConstructorMethodName}")) + associateTypes(i, constructorPaths) + + /** Visits an identifier being assigned to a call's return value. + */ + protected def visitIdentifierAssignedToCallRetVal(i: Identifier, c: Call): Set[String] = + if symbolTable.contains(c) then + val callReturns = methodReturnValues(symbolTable.get(c).toSeq) + associateTypes(i, callReturns) + else if c.argument.exists(_.argumentIndex == 0) then + val callFullNames = (c.argument(0) match + case i: Identifier if symbolTable.contains(LocalVar(i.name)) => + symbolTable.get(LocalVar(i.name)) + case i: Identifier if symbolTable.contains(CallAlias(i.name)) => + symbolTable.get(CallAlias(i.name)) + case _ => Set.empty + ).map(_.concat(s"$pathSep${c.name}")).toSeq + val callReturns = methodReturnValues(callFullNames) + associateTypes(i, callReturns) + else + // Assign dummy value + associateTypes(i, Set(s"${c.name}$pathSep${XTypeRecovery.DummyReturnType}")) + + /** Will attempt to find the return values of a method if in the CPG, otherwise will give a dummy + * value. + */ + protected def methodReturnValues(methodFullNames: Seq[String]): Set[String] = + val rs = cpg.method + .fullNameExact(methodFullNames*) + .methodReturn + .flatMap(mr => mr.typeFullName +: mr.dynamicTypeHintFullName) + .filterNot(_.equals("ANY")) + .toSet + if rs.isEmpty then + methodFullNames.map(_.concat(s"$pathSep${XTypeRecovery.DummyReturnType}")).toSet + else rs + + /** Will handle literal value assignments. Override if special handling is required. + */ + protected def visitIdentifierAssignedToLiteral(i: Identifier, l: Literal): Set[String] = + associateTypes(i, getLiteralType(l)) + + /** Not all frontends populate typeFullName for literals so we allow this to be + * overridden. + */ + protected def getLiteralType(l: Literal): Set[String] = Set(l.typeFullName) + + /** Will handle an identifier holding a function pointer. + */ + protected def visitIdentifierAssignedToMethodRef( + i: Identifier, + m: MethodRef, + rec: Option[String] = None + ): Set[String] = + symbolTable.append(CallAlias(i.name, rec), Set(m.methodFullName)) + + /** Will handle an identifier holding a type pointer. + */ + protected def visitIdentifierAssignedToTypeRef( + i: Identifier, + t: TypeRef, + rec: Option[String] = None + ): Set[String] = + symbolTable.append(CallAlias(i.name, rec), Set(t.typeFullName)) + + /** Visits a call assigned to an identifier. This is often when there are operators involved. + */ + protected def visitCallAssignedToIdentifier(c: Call, i: Identifier): Set[String] = + val rhsTypes = symbolTable.get(i) + assignTypesToCall(c, rhsTypes) + + /** Visits a call assigned to the return value of a call. This is often when there are operators + * involved. + */ + protected def visitCallAssignedToCall(x: Call, y: Call): Set[String] = + assignTypesToCall(x, getTypesFromCall(y)) + + /** Given a call operation, will attempt to retrieve types from it. + */ + protected def getTypesFromCall(c: Call): Set[String] = c.name match + case Operators.fieldAccess => symbolTable.get(LocalVar(getFieldName(new FieldAccess(c)))) + case _ if symbolTable.contains(c) => symbolTable.get(c) + case Operators.indexAccess => getIndexAccessTypes(c) + case n => + logger.debug(s"Unknown RHS call type '$n' @ ${debugLocation(c)}") + Set.empty[String] + + /** Given a LHS call, will retrieve its symbol to the given types. + */ + protected def assignTypesToCall(x: Call, types: Set[String]): Set[String] = + if types.nonEmpty then + getSymbolFromCall(x) match + case (lhs, globalKeys) if globalKeys.nonEmpty => + globalKeys.foreach { (fieldVar: FieldPath) => + persistMemberWithTypeDecl( + fieldVar.compUnitFullName, + fieldVar.identifier, + types + ) + } + symbolTable.append(lhs, types) + case (lhs, _) => symbolTable.append(lhs, types) + else Set.empty + + /** Will attempt to retrieve index access types otherwise will return dummy value. + */ + protected def getIndexAccessTypes(ia: Call): Set[String] = indexAccessToCollectionVar(ia) match + case Some(cVar) if symbolTable.contains(cVar) => + symbolTable.get(cVar) + case Some(cVar) if symbolTable.contains(LocalVar(cVar.identifier)) => + symbolTable.get(LocalVar(cVar.identifier)).map(x => + s"$x$pathSep${XTypeRecovery.DummyIndexAccess}" ) - end indexAccessToCollectionVar - - /** Will handle an identifier being assigned to a field value. - */ - protected def visitIdentifierAssignedToFieldLoad(i: Identifier, fa: FieldAccess): Set[String] = + case _ => Set.empty + + /** Convenience class for transporting field names. + * @param compUnitFullName + * qualified path to base type holding the member. + * @param identifier + * the member name. + */ + case class FieldPath(compUnitFullName: String, identifier: String) + + /** Tries to identify the underlying symbol from the call operation as it is used on the LHS of an + * assignment. The second element is a list of any associated global keys if applicable. + */ + protected def getSymbolFromCall(c: Call): (LocalKey, Set[FieldPath]) = c.name match + case Operators.fieldAccess => + val fa = new FieldAccess(c) val fieldName = getFieldName(fa) - fa.argumentOut.l match - case ::(base: Identifier, ::(fi: FieldIdentifier, _)) - if symbolTable.contains(LocalVar(base.name)) => - // Get field from global table if referenced as a variable - val localTypes = symbolTable.get(LocalVar(base.name)) - associateInterproceduralTypes(i, base, fi, fieldName, localTypes) - case ::(base: Identifier, ::(fi: FieldIdentifier, _)) - if symbolTable.contains(LocalVar(fieldName)) => - val localTypes = symbolTable.get(LocalVar(fieldName)) - associateInterproceduralTypes(i, base, fi, fieldName, localTypes) - case ::(base: Identifier, ::(fi: FieldIdentifier, _)) => - val dummyTypes = Set(s"$fieldName$pathSep${XTypeRecovery.DummyReturnType}") - associateInterproceduralTypes(i, base, fi, fieldName, dummyTypes) - case ::(c: Call, ::(fi: FieldIdentifier, _)) if c.name.equals(Operators.fieldAccess) => - val baseName = getFieldName(new FieldAccess(c)) - // Build type regardless of length - // TODO: This is more prone to giving dummy values as it does not do global look-ups - // but this is okay for now - val buf = mutable.ArrayBuffer.empty[String] - for segment <- baseName.split(pathSep) ++ Array(fi.canonicalName) do - val types = - if buf.isEmpty then symbolTable.get(LocalVar(segment)) - else - buf.flatMap(t => symbolTable.get(LocalVar(s"$t$pathSep$segment"))).toSet - if types.nonEmpty then - buf.clear() - buf.addAll(types) - else - val bufCopy = Array.from(buf) - buf.clear() - bufCopy.foreach { - case t if isConstructor(segment) => buf.addOne(s"$t$pathSep$segment") - case t => buf.addOne(XTypeRecovery.dummyMemberType(t, segment, pathSep)) - } - associateTypes(i, buf.toSet) - case ::(call: Call, ::(fi: FieldIdentifier, _)) => - assignTypesToCall( - call, - Set(fieldName.stripSuffix( - s"${XTypeRecovery.DummyMemberLoad}$pathSep${fi.canonicalName}" - )) - ) - case _ => - logger.debug( - s"Unable to assign identifier '${i.name}' to field load '$fieldName' @ ${debugLocation(i)}" - ) - Set.empty - end match - end visitIdentifierAssignedToFieldLoad - - /** Visits an identifier being assigned to the result of an index access operation. - */ - protected def visitIdentifierAssignedToIndexAccess(i: Identifier, c: Call): Set[String] = - associateTypes(i, getTypesFromCall(c)) - - /** Visits an identifier that is the target of a cast operation. - */ - protected def visitIdentifierAssignedToCast(i: Identifier, c: Call): Set[String] = - associateTypes(i, (c.typeFullName +: c.dynamicTypeHintFullName).filterNot(_ == "ANY").toSet) - - protected def getFieldBaseType(base: Identifier, fi: FieldIdentifier): Set[String] = - getFieldBaseType(base.name, fi.canonicalName) - - protected def getFieldBaseType(baseName: String, fieldName: String): Set[String] = - symbolTable - .get(LocalVar(baseName)) - .flatMap(t => typeDeclIterator(t).member.nameExact(fieldName)) - .typeFullNameNot("ANY") - .flatMap(m => m.typeFullName +: m.dynamicTypeHintFullName) - .toSet - - protected def visitReturns(ret: Return): Unit = - val m = ret.method - val existingTypes = mutable.HashSet.from( - (m.methodReturn.typeFullName +: m.methodReturn.dynamicTypeHintFullName) - .filterNot(_ == "ANY") + (LocalVar(fieldName), getFieldParents(fa).map(fp => FieldPath(fp, fieldName))) + case Operators.indexAccess => + (indexAccessToCollectionVar(c).getOrElse(LocalVar(c.name)), Set.empty) + case x => + logger.debug(s"Using default LHS call name '$x' @ ${debugLocation(c)}") + (LocalVar(c.name), Set.empty) + + /** Extracts a string representation of the name of the field within this field access. + */ + protected def getFieldName(fa: FieldAccess, prefix: String = "", suffix: String = ""): String = + def wrapName(n: String) = + val sb = new mutable.StringBuilder() + if prefix.nonEmpty then sb.append(s"$prefix$pathSep") + sb.append(n) + if suffix.nonEmpty then sb.append(s"$pathSep$suffix") + sb.toString() + + lazy val typesFromBaseCall = fa.argumentOut.headOption match + case Some(call: Call) => getTypesFromCall(call) + case _ => Set.empty[String] + + fa.argumentOut.l match + case ::(i: Identifier, ::(f: FieldIdentifier, _)) if i.name.matches("(self|this)") => + wrapName(f.canonicalName) + case ::(i: Identifier, ::(f: FieldIdentifier, _)) => + wrapName(s"${i.name}$pathSep${f.canonicalName}") + case ::(c: Call, ::(f: FieldIdentifier, _)) if c.name.equals(Operators.fieldAccess) => + wrapName(getFieldName(new FieldAccess(c), suffix = f.canonicalName)) + case ::(_: Call, ::(f: FieldIdentifier, _)) if typesFromBaseCall.nonEmpty => + // TODO: Handle this case better + wrapName(s"${typesFromBaseCall.head}$pathSep${f.canonicalName}") + case ::(f: FieldIdentifier, ::(c: Call, _)) if c.name.equals(Operators.fieldAccess) => + wrapName(getFieldName(new FieldAccess(c), prefix = f.canonicalName)) + case ::(c: Call, ::(f: FieldIdentifier, _)) => + // TODO: Handle this case better + val callCode = + if c.code.contains("(") then c.code.substring(c.code.indexOf("(")) else c.code + XTypeRecovery.dummyMemberType(callCode, f.canonicalName, pathSep) + case xs => + logger.debug( + s"Unhandled field structure ${xs.map(x => (x.label, x.code)).mkString(",")} @ ${debugLocation(fa)}" + ) + wrapName("") + end match + end getFieldName + + protected def visitCallAssignedToLiteral(c: Call, l: Literal): Set[String] = + if c.name.equals(Operators.indexAccess) then + // For now, we will just handle this on a very basic level + c.argumentOut.l match + case List(_: Identifier, _: Literal) => + indexAccessToCollectionVar(c).map(cv => + symbolTable.append(cv, getLiteralType(l)) + ).getOrElse(Set.empty) + case List(_: Identifier, idx: Identifier) if symbolTable.contains(idx) => + // Imprecise but sound! + indexAccessToCollectionVar(c).map(cv => + symbolTable.append(cv, symbolTable.get(idx)) + ).getOrElse(Set.empty) + case List(i: Identifier, c: Call) => + // This is an expensive level of precision to support + symbolTable.append(CollectionVar(i.name, "*"), getTypesFromCall(c)) + case List(c: Call, l: Literal) => assignTypesToCall(c, getLiteralType(l)) + case xs => + logger.debug( + s"Unhandled index access point assigned to literal ${xs.map(x => + (x.label, x.code) + ).mkString(",")} @ ${debugLocation(c)}" + ) + Set.empty + else if c.name.equals(Operators.fieldAccess) then + val fa = new FieldAccess(c) + val fieldName = getFieldName(fa) + associateTypes(LocalVar(fieldName), fa, getLiteralType(l)) + else + logger.debug( + s"Unhandled call assigned to literal point ${c.name} @ ${debugLocation(c)}" ) - @tailrec - def extractTypes(xs: List[CfgNode]): Set[String] = xs match - case ::(head: Literal, Nil) if head.typeFullName != "ANY" => - Set(head.typeFullName) - case ::(head: Call, Nil) if head.name == Operators.fieldAccess => - val fieldAccess = new FieldAccess(head) - val (sym, ts) = getSymbolFromCall(fieldAccess) - val cpgTypes = cpg.typeDecl - .fullNameExact(ts.map(_.compUnitFullName).toSeq*) - .member - .nameExact(sym.identifier) - .flatMap(m => m.typeFullName +: m.dynamicTypeHintFullName) - .filterNot { x => x == "ANY" || x == "this" } - .toSet - if cpgTypes.nonEmpty then cpgTypes - else symbolTable.get(sym) - case ::(head: Call, Nil) if symbolTable.contains(head) => - val callPaths = symbolTable.get(head) - val returnValues = methodReturnValues(callPaths.toSeq) - if returnValues.isEmpty then - callPaths.map(c => s"$c$pathSep${XTypeRecovery.DummyReturnType}") + Set.empty + + /** Handles a call operation assigned to a method/function pointer. + */ + protected def visitCallAssignedToMethodRef(c: Call, m: MethodRef): Set[String] = + assignTypesToCall(c, Set(m.methodFullName)) + + /** Generates an identifier for collection/index-access operations in the symbol table. + */ + protected def indexAccessToCollectionVar(c: Call): Option[CollectionVar] = + def callName(x: Call) = + if x.name.equals(Operators.fieldAccess) then + getFieldName(new FieldAccess(x)) + else if x.name.equals(Operators.indexAccess) then + indexAccessToCollectionVar(x) + .map(cv => s"${cv.identifier}[${cv.idx}]") + .getOrElse(XTypeRecovery.DummyIndexAccess) + else x.name + + Option(c.argumentOut.l match + case List(i: Identifier, idx: Literal) => CollectionVar(i.name, idx.code) + case List(i: Identifier, idx: Identifier) => CollectionVar(i.name, idx.code) + case List(c: Call, idx: Call) => CollectionVar(callName(c), callName(idx)) + case List(c: Call, idx: Literal) => CollectionVar(callName(c), idx.code) + case List(c: Call, idx: Identifier) => CollectionVar(callName(c), idx.code) + case xs => + logger.debug( + s"Unhandled index access ${xs.map(x => (x.label, x.code)).mkString(",")} @ ${debugLocation(c)}" + ) + null + ) + end indexAccessToCollectionVar + + /** Will handle an identifier being assigned to a field value. + */ + protected def visitIdentifierAssignedToFieldLoad(i: Identifier, fa: FieldAccess): Set[String] = + val fieldName = getFieldName(fa) + fa.argumentOut.l match + case ::(base: Identifier, ::(fi: FieldIdentifier, _)) + if symbolTable.contains(LocalVar(base.name)) => + // Get field from global table if referenced as a variable + val localTypes = symbolTable.get(LocalVar(base.name)) + associateInterproceduralTypes(i, base, fi, fieldName, localTypes) + case ::(base: Identifier, ::(fi: FieldIdentifier, _)) + if symbolTable.contains(LocalVar(fieldName)) => + val localTypes = symbolTable.get(LocalVar(fieldName)) + associateInterproceduralTypes(i, base, fi, fieldName, localTypes) + case ::(base: Identifier, ::(fi: FieldIdentifier, _)) => + val dummyTypes = Set(s"$fieldName$pathSep${XTypeRecovery.DummyReturnType}") + associateInterproceduralTypes(i, base, fi, fieldName, dummyTypes) + case ::(c: Call, ::(fi: FieldIdentifier, _)) if c.name.equals(Operators.fieldAccess) => + val baseName = getFieldName(new FieldAccess(c)) + // Build type regardless of length + // TODO: This is more prone to giving dummy values as it does not do global look-ups + // but this is okay for now + val buf = mutable.ArrayBuffer.empty[String] + for segment <- baseName.split(pathSep) ++ Array(fi.canonicalName) do + val types = + if buf.isEmpty then symbolTable.get(LocalVar(segment)) else - returnValues - case ::(head: Call, Nil) if head.argumentOut.headOption.exists(symbolTable.contains) => - symbolTable - .get(head.argumentOut.head) - .map(t => - Seq(t, head.name, XTypeRecovery.DummyReturnType).mkString(pathSep.toString) - ) - case ::(identifier: Identifier, Nil) if symbolTable.contains(identifier) => - symbolTable.get(identifier) - case ::(head: Call, Nil) => - extractTypes(head.argument.l) - case _ => Set.empty - val returnTypes = extractTypes(ret.argumentOut.l) - existingTypes.addAll(returnTypes) - builder.setNodeProperty( - ret.method.methodReturn, - PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, - existingTypes - ) - end visitReturns - - /** Using an entry from the symbol table, will queue the CPG modification to persist the - * recovered type information. - */ - protected def setTypeInformation(): Unit = - cu.ast - .collect { - case n: Local => n - case n: Call => n - case n: Expression => n - case n: MethodParameterIn if state.isFinalIteration => n - case n: MethodReturn if state.isFinalIteration => n - } - .foreach { - case x: Local if symbolTable.contains(x) => - storeNodeTypeInfo(x, symbolTable.get(x).toSeq) - case x: MethodParameterIn => setTypeFromTypeHints(x) - case x: MethodReturn => setTypeFromTypeHints(x) - case x: Identifier if symbolTable.contains(x) => - setTypeInformationForRecCall(x, x.inCall.headOption, x.inCall.argument.l) - case x: Call => - if symbolTable.contains(x) then - val typs = - if state.config.enabledDummyTypes then symbolTable.get(x).toSeq - else symbolTable.get(x).filterNot(XTypeRecovery.isDummyType).toSeq - storeCallTypeInfo(x, typs) - else if x.argument.headOption.exists(symbolTable.contains) then - setTypeInformationForRecCall(x, Option(x), x.argument.l) - else if !x.name.startsWith("<") && !x.code.contains( - "require" - ) && !x.code.contains("this") - then - storeCallTypeInfo(x, Seq(x.code.takeWhile(_ != '('))) - case x: Identifier - if symbolTable.contains(CallAlias(x.name)) && x.inCall.nonEmpty => - setTypeInformationForRecCall(x, x.inCall.headOption, x.inCall.argument.l) - case _ => - } - // Set types in an atomic way - newTypesForMembers.foreach { case (m, ts) => storeDefaultTypeInfo(m, ts.toSeq) } - end setTypeInformation - - protected def createCallFromIdentifierTypeFullName( - typeFullName: String, - callName: String - ): String = - s"$typeFullName$pathSep$callName" - - /** Sets type information for a receiver/call pattern. - */ - private def setTypeInformationForRecCall(x: AstNode, n: Option[Call], ms: List[AstNode]): Unit = - (n, ms) match - // Case 1: 'call' is an assignment from some dynamic dispatch call - case (Some(call: Call), ::(i: Identifier, ::(c: Call, _))) - if call.name == Operators.assignment => - setTypeForIdentifierAssignedToCall(call, i, c) - // Case 1: 'call' is an assignment from some other data structure - case (Some(call: Call), ::(i: Identifier, _)) if call.name == Operators.assignment => - setTypeForIdentifierAssignedToDefault(call, i) - // Case 2: 'i' is the receiver of 'call' - case (Some(call: Call), ::(i: Identifier, _)) if call.name != Operators.fieldAccess => - setTypeForDynamicDispatchCall(call, i) - // Case 3: 'i' is the receiver for a field access on member 'f' - case (Some(fieldAccess: Call), ::(i: Identifier, ::(f: FieldIdentifier, _))) - if fieldAccess.name == Operators.fieldAccess => - setTypeForFieldAccess(new FieldAccess(fieldAccess), i, f) - case _ => - // Handle the node itself - x match - case c: Call if c.name.startsWith(" - case _ => persistType(x, symbolTable.get(x)) - end setTypeInformationForRecCall - - protected def setTypeForFieldAccess( - fieldAccess: Call, - i: Identifier, - f: FieldIdentifier - ): Unit = - val idHints = if symbolTable.contains(i) then symbolTable.get(i) - else symbolTable.get(CallAlias(i.name)) - val callTypes = symbolTable.get(fieldAccess) - persistType(i, idHints) - persistType(fieldAccess, callTypes) - fieldAccess.astParent.iterator.isCall.headOption match - case Some(callFromFieldName) if symbolTable.contains(callFromFieldName) => - persistType(callFromFieldName, symbolTable.get(callFromFieldName)) - case _ => - // This field may be a function pointer - handlePotentialFunctionPointer(fieldAccess, idHints, f.canonicalName, Option(i.name)) - - protected def setTypeForDynamicDispatchCall(call: Call, i: Identifier): Unit = - val idHints = symbolTable.get(i) - val callTypes = symbolTable.get(call) - persistType(i, idHints) - if callTypes.isEmpty && !call.name.startsWith("") then - // For now, calls are treated as function pointers and thus the type should point to the method - persistType(call, idHints.map(t => createCallFromIdentifierTypeFullName(t, call.name))) - else - persistType(call, callTypes) - - protected def setTypeForIdentifierAssignedToDefault(call: Call, i: Identifier): Unit = - val idHints = symbolTable.get(i) - persistType(i, idHints) - persistType(call, idHints) - - protected def setTypeForIdentifierAssignedToCall(call: Call, i: Identifier, c: Call): Unit = - val idTypes = if symbolTable.contains(i) then symbolTable.get(i) - else symbolTable.get(CallAlias(i.name)) - val callTypes = symbolTable.get(c) - persistType(call, callTypes) - if idTypes.nonEmpty || callTypes.nonEmpty then - if idTypes.equals(callTypes) then - // Case 1.1: This is a function pointer or constructor - persistType(i, callTypes) + buf.flatMap(t => symbolTable.get(LocalVar(s"$t$pathSep$segment"))).toSet + if types.nonEmpty then + buf.clear() + buf.addAll(types) else - // Case 1.2: This is the return value of the function - persistType(i, idTypes) - - protected def setTypeFromTypeHints(n: StoredNode): Unit = - val types = n.getKnownTypes.filterNot(XTypeRecovery.isDummyType) - if types.nonEmpty then setTypes(n, types.toSeq) - - /** In the case this field access is a function pointer, we would want to make sure this has a - * method ref. - */ - private def handlePotentialFunctionPointer( - funcPtr: Expression, - baseTypes: Set[String], - funcName: String, - baseName: Option[String] = None - ): Unit = - // Sometimes the function identifier is an argument to the call itself as a "base". In this case we don't need - // a method ref. This happens in jssrc2cpg - if !funcPtr.astParent.iterator.collectAll[Call].exists(_.name == funcName) then - baseTypes - .map(t => if t.endsWith(funcName) then t else s"$t$pathSep$funcName") - .flatMap(cpg.method.fullNameExact) - .filterNot(m => - addedNodes.contains( - s"${funcPtr.id()}${NodeTypes.METHOD_REF}$pathSep${m.fullName}" - ) - ) - .map(m => - m -> createMethodRef( - baseName, - funcName, - m.fullName, - funcPtr.lineNumber, - funcPtr.columnNumber - ) - ) - .foreach { case (m, mRef) => - funcPtr.astParent - .filterNot( - _.astChildren.isMethodRef.exists(_.methodFullName == mRef.methodFullName) - ) - .foreach { inCall => - state.changesWereMade.compareAndSet(false, true) - integrateMethodRef(funcPtr, m, mRef, inCall) - } - } - - private def createMethodRef( - baseName: Option[String], - funcName: String, - methodFullName: String, - lineNo: Option[Integer], - columnNo: Option[Integer] - ): NewMethodRef = - NewMethodRef() - .code(s"${baseName.map(_.appended(pathSep)).getOrElse("")}$funcName") - .methodFullName(methodFullName) - .lineNumber(lineNo) - .columnNumber(columnNo) - - /** Integrate this method ref node into the CPG according to schema rules. Since we're adding - * this after the base passes, we need to add the necessary linking manually. - */ - private def integrateMethodRef( - funcPtr: Expression, - m: Method, - mRef: NewMethodRef, - inCall: AstNode - ) = - builder.addNode(mRef) - builder.addEdge(mRef, m, EdgeTypes.REF) - builder.addEdge(inCall, mRef, EdgeTypes.AST) - builder.addEdge(funcPtr.method, mRef, EdgeTypes.CONTAINS) - inCall match + val bufCopy = Array.from(buf) + buf.clear() + bufCopy.foreach { + case t if isConstructor(segment) => buf.addOne(s"$t$pathSep$segment") + case t => buf.addOne(XTypeRecovery.dummyMemberType(t, segment, pathSep)) + } + associateTypes(i, buf.toSet) + case ::(call: Call, ::(fi: FieldIdentifier, _)) => + assignTypesToCall( + call, + Set(fieldName.stripSuffix( + s"${XTypeRecovery.DummyMemberLoad}$pathSep${fi.canonicalName}" + )) + ) + case _ => + logger.debug( + s"Unable to assign identifier '${i.name}' to field load '$fieldName' @ ${debugLocation(i)}" + ) + Set.empty + end match + end visitIdentifierAssignedToFieldLoad + + /** Visits an identifier being assigned to the result of an index access operation. + */ + protected def visitIdentifierAssignedToIndexAccess(i: Identifier, c: Call): Set[String] = + associateTypes(i, getTypesFromCall(c)) + + /** Visits an identifier that is the target of a cast operation. + */ + protected def visitIdentifierAssignedToCast(i: Identifier, c: Call): Set[String] = + associateTypes(i, (c.typeFullName +: c.dynamicTypeHintFullName).filterNot(_ == "ANY").toSet) + + protected def getFieldBaseType(base: Identifier, fi: FieldIdentifier): Set[String] = + getFieldBaseType(base.name, fi.canonicalName) + + protected def getFieldBaseType(baseName: String, fieldName: String): Set[String] = + symbolTable + .get(LocalVar(baseName)) + .flatMap(t => typeDeclIterator(t).member.nameExact(fieldName)) + .typeFullNameNot("ANY") + .flatMap(m => m.typeFullName +: m.dynamicTypeHintFullName) + .toSet + + protected def visitReturns(ret: Return): Unit = + val m = ret.method + val existingTypes = mutable.HashSet.from( + (m.methodReturn.typeFullName +: m.methodReturn.dynamicTypeHintFullName) + .filterNot(_ == "ANY") + ) + @tailrec + def extractTypes(xs: List[CfgNode]): Set[String] = xs match + case ::(head: Literal, Nil) if head.typeFullName != "ANY" => + Set(head.typeFullName) + case ::(head: Call, Nil) if head.name == Operators.fieldAccess => + val fieldAccess = new FieldAccess(head) + val (sym, ts) = getSymbolFromCall(fieldAccess) + val cpgTypes = cpg.typeDecl + .fullNameExact(ts.map(_.compUnitFullName).toSeq*) + .member + .nameExact(sym.identifier) + .flatMap(m => m.typeFullName +: m.dynamicTypeHintFullName) + .filterNot { x => x == "ANY" || x == "this" } + .toSet + if cpgTypes.nonEmpty then cpgTypes + else symbolTable.get(sym) + case ::(head: Call, Nil) if symbolTable.contains(head) => + val callPaths = symbolTable.get(head) + val returnValues = methodReturnValues(callPaths.toSeq) + if returnValues.isEmpty then + callPaths.map(c => s"$c$pathSep${XTypeRecovery.DummyReturnType}") + else + returnValues + case ::(head: Call, Nil) if head.argumentOut.headOption.exists(symbolTable.contains) => + symbolTable + .get(head.argumentOut.head) + .map(t => + Seq(t, head.name, XTypeRecovery.DummyReturnType).mkString(pathSep.toString) + ) + case ::(identifier: Identifier, Nil) if symbolTable.contains(identifier) => + symbolTable.get(identifier) + case ::(head: Call, Nil) => + extractTypes(head.argument.l) + case _ => Set.empty + val returnTypes = extractTypes(ret.argumentOut.l) + existingTypes.addAll(returnTypes) + builder.setNodeProperty( + ret.method.methodReturn, + PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, + existingTypes + ) + end visitReturns + + /** Using an entry from the symbol table, will queue the CPG modification to persist the recovered + * type information. + */ + protected def setTypeInformation(): Unit = + cu.ast + .collect { + case n: Local => n + case n: Call => n + case n: Expression => n + case n: MethodParameterIn if state.isFinalIteration => n + case n: MethodReturn if state.isFinalIteration => n + } + .foreach { + case x: Local if symbolTable.contains(x) => + storeNodeTypeInfo(x, symbolTable.get(x).toSeq) + case x: MethodParameterIn => setTypeFromTypeHints(x) + case x: MethodReturn => setTypeFromTypeHints(x) + case x: Identifier if symbolTable.contains(x) => + setTypeInformationForRecCall(x, x.inCall.headOption, x.inCall.argument.l) case x: Call => - builder.addEdge(x, mRef, EdgeTypes.ARGUMENT) - mRef.argumentIndex(x.argumentOut.size + 1) - case x => - mRef.argumentIndex(x.astChildren.size + 1) - addedNodes.add(s"${funcPtr.id()}${NodeTypes.METHOD_REF}$pathSep${mRef.methodFullName}") - - protected def persistType(x: StoredNode, types: Set[String]): Unit = - val filteredTypes = if state.config.enabledDummyTypes then types - else types.filterNot(XTypeRecovery.isDummyType) - if filteredTypes.nonEmpty then - storeNodeTypeInfo(x, filteredTypes.toSeq) - x match - case i: Identifier if symbolTable.contains(i) => - if isField(i) then persistMemberType(i, filteredTypes) - handlePotentialFunctionPointer(i, filteredTypes, i.name) - case _ => - - private def persistMemberType(i: Identifier, types: Set[String]): Unit = - getLocalMember(i) match - case Some(m) => storeNodeTypeInfo(m, types.toSeq) - case None => - - /** Type decls where member access are required need to point to the correct type decl that - * holds said members. This allows implementations to use the type names to find the correct - * type holding members. - * @param typeFullName - * the type full name. - * @return - * the type full name that has member children. - */ - protected def typeDeclIterator(typeFullName: String): Iterator[TypeDecl] = - cpg.typeDecl.fullNameExact(typeFullName) - - /** Given a type full name and member name, will persist the given types to the member. - * @param typeFullName - * the type full name. - * @param memberName - * the member name. - * @param types - * the types to associate. - */ - protected def persistMemberWithTypeDecl( - typeFullName: String, - memberName: String, - types: Set[String] - ): Unit = - typeDeclIterator(typeFullName).member.nameExact(memberName).headOption.foreach { m => - storeNodeTypeInfo(m, types.toSeq) + if symbolTable.contains(x) then + val typs = + if state.config.enabledDummyTypes then symbolTable.get(x).toSeq + else symbolTable.get(x).filterNot(XTypeRecovery.isDummyType).toSeq + storeCallTypeInfo(x, typs) + else if x.argument.headOption.exists(symbolTable.contains) then + setTypeInformationForRecCall(x, Option(x), x.argument.l) + else if !x.name.startsWith("<") && !x.code.contains( + "require" + ) && !x.code.contains("this") + then + storeCallTypeInfo(x, Seq(x.code.takeWhile(_ != '('))) + case x: Identifier + if symbolTable.contains(CallAlias(x.name)) && x.inCall.nonEmpty => + setTypeInformationForRecCall(x, x.inCall.headOption, x.inCall.argument.l) + case _ => } - - /** Given an identifier that has been determined to be a field, an attempt is made to get the - * corresponding member. This implementation follows more the way dynamic languages define - * method/type relations. - * @param i - * the identifier. - * @return - * the corresponding member, if found - */ - protected def getLocalMember(i: Identifier): Option[Member] = - typeDeclIterator(i.method.typeDecl.fullName.headOption.getOrElse(i.method.fullName)).member - .nameExact(i.name) - .headOption - - private def storeNodeTypeInfo(storedNode: StoredNode, types: Seq[String]): Unit = - lazy val existingTypes = storedNode.getKnownTypes - - if types.nonEmpty && types.toSet != existingTypes then - storedNode match - case m: Member => - // To avoid overwriting member updates, we store them elsewhere until the end - newTypesForMembers.updateWith(m) { - case Some(ts) => Some(ts ++ types) - case None => Some(types.toSet) - } - case i: Identifier => storeIdentifierTypeInfo(i, types) - case l: Local => storeLocalTypeInfo(l, types) - case c: Call if !c.name.startsWith("") => storeCallTypeInfo(c, types) - case _: Call => - case n => - state.changesWereMade.compareAndSet(false, true) - setTypes(n, types) - - protected def storeCallTypeInfo(c: Call, types: Seq[String]): Unit = - if types.nonEmpty then - state.changesWereMade.compareAndSet(false, true) - builder.setNodeProperty( - c, - PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, - (c.dynamicTypeHintFullName ++ types).distinct + // Set types in an atomic way + newTypesForMembers.foreach { case (m, ts) => storeDefaultTypeInfo(m, ts.toSeq) } + end setTypeInformation + + protected def createCallFromIdentifierTypeFullName( + typeFullName: String, + callName: String + ): String = + s"$typeFullName$pathSep$callName" + + /** Sets type information for a receiver/call pattern. + */ + private def setTypeInformationForRecCall(x: AstNode, n: Option[Call], ms: List[AstNode]): Unit = + (n, ms) match + // Case 1: 'call' is an assignment from some dynamic dispatch call + case (Some(call: Call), ::(i: Identifier, ::(c: Call, _))) + if call.name == Operators.assignment => + setTypeForIdentifierAssignedToCall(call, i, c) + // Case 1: 'call' is an assignment from some other data structure + case (Some(call: Call), ::(i: Identifier, _)) if call.name == Operators.assignment => + setTypeForIdentifierAssignedToDefault(call, i) + // Case 2: 'i' is the receiver of 'call' + case (Some(call: Call), ::(i: Identifier, _)) if call.name != Operators.fieldAccess => + setTypeForDynamicDispatchCall(call, i) + // Case 3: 'i' is the receiver for a field access on member 'f' + case (Some(fieldAccess: Call), ::(i: Identifier, ::(f: FieldIdentifier, _))) + if fieldAccess.name == Operators.fieldAccess => + setTypeForFieldAccess(new FieldAccess(fieldAccess), i, f) + case _ => + // Handle the node itself + x match + case c: Call if c.name.startsWith(" + case _ => persistType(x, symbolTable.get(x)) + end setTypeInformationForRecCall + + protected def setTypeForFieldAccess( + fieldAccess: Call, + i: Identifier, + f: FieldIdentifier + ): Unit = + val idHints = if symbolTable.contains(i) then symbolTable.get(i) + else symbolTable.get(CallAlias(i.name)) + val callTypes = symbolTable.get(fieldAccess) + persistType(i, idHints) + persistType(fieldAccess, callTypes) + fieldAccess.astParent.iterator.isCall.headOption match + case Some(callFromFieldName) if symbolTable.contains(callFromFieldName) => + persistType(callFromFieldName, symbolTable.get(callFromFieldName)) + case _ => + // This field may be a function pointer + handlePotentialFunctionPointer(fieldAccess, idHints, f.canonicalName, Option(i.name)) + + protected def setTypeForDynamicDispatchCall(call: Call, i: Identifier): Unit = + val idHints = symbolTable.get(i) + val callTypes = symbolTable.get(call) + persistType(i, idHints) + if callTypes.isEmpty && !call.name.startsWith("") then + // For now, calls are treated as function pointers and thus the type should point to the method + persistType(call, idHints.map(t => createCallFromIdentifierTypeFullName(t, call.name))) + else + persistType(call, callTypes) + + protected def setTypeForIdentifierAssignedToDefault(call: Call, i: Identifier): Unit = + val idHints = symbolTable.get(i) + persistType(i, idHints) + persistType(call, idHints) + + protected def setTypeForIdentifierAssignedToCall(call: Call, i: Identifier, c: Call): Unit = + val idTypes = if symbolTable.contains(i) then symbolTable.get(i) + else symbolTable.get(CallAlias(i.name)) + val callTypes = symbolTable.get(c) + persistType(call, callTypes) + if idTypes.nonEmpty || callTypes.nonEmpty then + if idTypes.equals(callTypes) then + // Case 1.1: This is a function pointer or constructor + persistType(i, callTypes) + else + // Case 1.2: This is the return value of the function + persistType(i, idTypes) + + protected def setTypeFromTypeHints(n: StoredNode): Unit = + val types = n.getKnownTypes.filterNot(XTypeRecovery.isDummyType) + if types.nonEmpty then setTypes(n, types.toSeq) + + /** In the case this field access is a function pointer, we would want to make sure this has a + * method ref. + */ + private def handlePotentialFunctionPointer( + funcPtr: Expression, + baseTypes: Set[String], + funcName: String, + baseName: Option[String] = None + ): Unit = + // Sometimes the function identifier is an argument to the call itself as a "base". In this case we don't need + // a method ref. This happens in jssrc2cpg + if !funcPtr.astParent.iterator.collectAll[Call].exists(_.name == funcName) then + baseTypes + .map(t => if t.endsWith(funcName) then t else s"$t$pathSep$funcName") + .flatMap(cpg.method.fullNameExact) + .filterNot(m => + addedNodes.contains( + s"${funcPtr.id()}${NodeTypes.METHOD_REF}$pathSep${m.fullName}" + ) ) + .map(m => + m -> createMethodRef( + baseName, + funcName, + m.fullName, + funcPtr.lineNumber, + funcPtr.columnNumber + ) + ) + .foreach { case (m, mRef) => + funcPtr.astParent + .filterNot( + _.astChildren.isMethodRef.exists(_.methodFullName == mRef.methodFullName) + ) + .foreach { inCall => + state.changesWereMade.compareAndSet(false, true) + integrateMethodRef(funcPtr, m, mRef, inCall) + } + } - /** Allows one to modify the types assigned to identifiers. - */ - protected def storeIdentifierTypeInfo(i: Identifier, types: Seq[String]): Unit = - storeDefaultTypeInfo(i, types) - - /** Allows one to modify the types assigned to nodes otherwise. - */ - protected def storeDefaultTypeInfo(n: StoredNode, types: Seq[String]): Unit = - if types.toSet != n.getKnownTypes then + private def createMethodRef( + baseName: Option[String], + funcName: String, + methodFullName: String, + lineNo: Option[Integer], + columnNo: Option[Integer] + ): NewMethodRef = + NewMethodRef() + .code(s"${baseName.map(_.appended(pathSep)).getOrElse("")}$funcName") + .methodFullName(methodFullName) + .lineNumber(lineNo) + .columnNumber(columnNo) + + /** Integrate this method ref node into the CPG according to schema rules. Since we're adding this + * after the base passes, we need to add the necessary linking manually. + */ + private def integrateMethodRef( + funcPtr: Expression, + m: Method, + mRef: NewMethodRef, + inCall: AstNode + ) = + builder.addNode(mRef) + builder.addEdge(mRef, m, EdgeTypes.REF) + builder.addEdge(inCall, mRef, EdgeTypes.AST) + builder.addEdge(funcPtr.method, mRef, EdgeTypes.CONTAINS) + inCall match + case x: Call => + builder.addEdge(x, mRef, EdgeTypes.ARGUMENT) + mRef.argumentIndex(x.argumentOut.size + 1) + case x => + mRef.argumentIndex(x.astChildren.size + 1) + addedNodes.add(s"${funcPtr.id()}${NodeTypes.METHOD_REF}$pathSep${mRef.methodFullName}") + + protected def persistType(x: StoredNode, types: Set[String]): Unit = + val filteredTypes = if state.config.enabledDummyTypes then types + else types.filterNot(XTypeRecovery.isDummyType) + if filteredTypes.nonEmpty then + storeNodeTypeInfo(x, filteredTypes.toSeq) + x match + case i: Identifier if symbolTable.contains(i) => + if isField(i) then persistMemberType(i, filteredTypes) + handlePotentialFunctionPointer(i, filteredTypes, i.name) + case _ => + + private def persistMemberType(i: Identifier, types: Set[String]): Unit = + getLocalMember(i) match + case Some(m) => storeNodeTypeInfo(m, types.toSeq) + case None => + + /** Type decls where member access are required need to point to the correct type decl that holds + * said members. This allows implementations to use the type names to find the correct type + * holding members. + * @param typeFullName + * the type full name. + * @return + * the type full name that has member children. + */ + protected def typeDeclIterator(typeFullName: String): Iterator[TypeDecl] = + cpg.typeDecl.fullNameExact(typeFullName) + + /** Given a type full name and member name, will persist the given types to the member. + * @param typeFullName + * the type full name. + * @param memberName + * the member name. + * @param types + * the types to associate. + */ + protected def persistMemberWithTypeDecl( + typeFullName: String, + memberName: String, + types: Set[String] + ): Unit = + typeDeclIterator(typeFullName).member.nameExact(memberName).headOption.foreach { m => + storeNodeTypeInfo(m, types.toSeq) + } + + /** Given an identifier that has been determined to be a field, an attempt is made to get the + * corresponding member. This implementation follows more the way dynamic languages define + * method/type relations. + * @param i + * the identifier. + * @return + * the corresponding member, if found + */ + protected def getLocalMember(i: Identifier): Option[Member] = + typeDeclIterator(i.method.typeDecl.fullName.headOption.getOrElse(i.method.fullName)).member + .nameExact(i.name) + .headOption + + private def storeNodeTypeInfo(storedNode: StoredNode, types: Seq[String]): Unit = + lazy val existingTypes = storedNode.getKnownTypes + + if types.nonEmpty && types.toSet != existingTypes then + storedNode match + case m: Member => + // To avoid overwriting member updates, we store them elsewhere until the end + newTypesForMembers.updateWith(m) { + case Some(ts) => Some(ts ++ types) + case None => Some(types.toSet) + } + case i: Identifier => storeIdentifierTypeInfo(i, types) + case l: Local => storeLocalTypeInfo(l, types) + case c: Call if !c.name.startsWith("") => storeCallTypeInfo(c, types) + case _: Call => + case n => state.changesWereMade.compareAndSet(false, true) - setTypes( - n, - (n.property(PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, Seq.empty) ++ types).distinct - ) + setTypes(n, types) + + protected def storeCallTypeInfo(c: Call, types: Seq[String]): Unit = + if types.nonEmpty then + state.changesWereMade.compareAndSet(false, true) + builder.setNodeProperty( + c, + PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, + (c.dynamicTypeHintFullName ++ types).distinct + ) - /** If there is only 1 type hint then this is set to the `typeFullName` property and - * `dynamicTypeHintFullName` is cleared. If not then `dynamicTypeHintFullName` is set to the - * types. - */ - protected def setTypes(n: StoredNode, types: Seq[String]): Unit = - if types.size == 1 then - builder.setNodeProperty(n, PropertyNames.TYPE_FULL_NAME, types.head) - builder.setNodeProperty(n, PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, Seq.empty) - else if types.size == 2 && types.last.nonEmpty && types.last != "null" then - builder.setNodeProperty(n, PropertyNames.TYPE_FULL_NAME, types.last) - builder.setNodeProperty(n, PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, Seq(types.head)) - else builder.setNodeProperty(n, PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, types) - - /** Allows one to modify the types assigned to locals. - */ - protected def storeLocalTypeInfo(l: Local, types: Seq[String]): Unit = - storeDefaultTypeInfo( - l, - if state.config.enabledDummyTypes then types - else types.filterNot(XTypeRecovery.isDummyType) + /** Allows one to modify the types assigned to identifiers. + */ + protected def storeIdentifierTypeInfo(i: Identifier, types: Seq[String]): Unit = + storeDefaultTypeInfo(i, types) + + /** Allows one to modify the types assigned to nodes otherwise. + */ + protected def storeDefaultTypeInfo(n: StoredNode, types: Seq[String]): Unit = + if types.toSet != n.getKnownTypes then + state.changesWereMade.compareAndSet(false, true) + setTypes( + n, + (n.property(PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, Seq.empty) ++ types).distinct ) - /** Allows an implementation to perform an operation once type persistence is complete. - */ - protected def postSetTypeInformation(): Unit = {} + /** If there is only 1 type hint then this is set to the `typeFullName` property and + * `dynamicTypeHintFullName` is cleared. If not then `dynamicTypeHintFullName` is set to the + * types. + */ + protected def setTypes(n: StoredNode, types: Seq[String]): Unit = + if types.size == 1 then + builder.setNodeProperty(n, PropertyNames.TYPE_FULL_NAME, types.head) + builder.setNodeProperty(n, PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, Seq.empty) + else if types.size == 2 && types.last.nonEmpty && types.last != "null" then + builder.setNodeProperty(n, PropertyNames.TYPE_FULL_NAME, types.last) + builder.setNodeProperty(n, PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, Seq(types.head)) + else builder.setNodeProperty(n, PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, types) + + /** Allows one to modify the types assigned to locals. + */ + protected def storeLocalTypeInfo(l: Local, types: Seq[String]): Unit = + storeDefaultTypeInfo( + l, + if state.config.enabledDummyTypes then types + else types.filterNot(XTypeRecovery.isDummyType) + ) - private val unknownTypePattern = s"(i?)(UNKNOWN|ANY|${Defines.UnresolvedNamespace}).*".r + /** Allows an implementation to perform an operation once type persistence is complete. + */ + protected def postSetTypeInformation(): Unit = {} - // The below are convenience calls for accessing type properties, one day when this pass uses `Tag` nodes instead of - // the symbol table then perhaps this would work out better - implicit class AllNodeTypesFromNodeExt(x: StoredNode): - def allTypes: Iterator[String] = - (x.property(PropertyNames.TYPE_FULL_NAME, "ANY") +: x.property( - PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, - Seq.empty - )).iterator + private val unknownTypePattern = s"(i?)(UNKNOWN|ANY|${Defines.UnresolvedNamespace}).*".r + + // The below are convenience calls for accessing type properties, one day when this pass uses `Tag` nodes instead of + // the symbol table then perhaps this would work out better + implicit class AllNodeTypesFromNodeExt(x: StoredNode): + def allTypes: Iterator[String] = + (x.property(PropertyNames.TYPE_FULL_NAME, "ANY") +: x.property( + PropertyNames.DYNAMIC_TYPE_HINT_FULL_NAME, + Seq.empty + )).iterator - def getKnownTypes: Set[String] = - x.allTypes.toSet.filterNot(unknownTypePattern.matches) + def getKnownTypes: Set[String] = + x.allTypes.toSet.filterNot(unknownTypePattern.matches) - implicit class AllNodeTypesFromIteratorExt(x: Iterator[StoredNode]): - def allTypes: Iterator[String] = x.flatMap(_.allTypes) + implicit class AllNodeTypesFromIteratorExt(x: Iterator[StoredNode]): + def allTypes: Iterator[String] = x.flatMap(_.allTypes) - def getKnownTypes: Set[String] = - x.allTypes.toSet.filterNot(unknownTypePattern.matches) + def getKnownTypes: Set[String] = + x.allTypes.toSet.filterNot(unknownTypePattern.matches) end RecoverForXCompilationUnit diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/taggers/CdxPass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/taggers/CdxPass.scala index 86ed9893..4616cd99 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/taggers/CdxPass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/taggers/CdxPass.scala @@ -16,342 +16,343 @@ import scala.io.Source */ class CdxPass(atom: Cpg) extends CpgPass(atom): - val language: String = atom.metaData.language.head + val language: String = atom.metaData.language.head - // Number of tags needed - private val TAGS_COUNT: Int = 2 + // Number of tags needed + private val TAGS_COUNT: Int = 2 - // Number of dots to use in the package namespace - // Example: org.apache.logging.* would be used for tagging purposes - private val PKG_NS_SIZE: Int = 3 + // Number of dots to use in the package namespace + // Example: org.apache.logging.* would be used for tagging purposes + private val PKG_NS_SIZE: Int = 3 - // tags list as a seed - private val keywords: List[String] = Source.fromResource("tags-vocab.txt").getLines.toList + // tags list as a seed + private val keywords: List[String] = Source.fromResource("tags-vocab.txt").getLines.toList - // FIXME: Replace these with semantic fingerprints - private def JS_REQUEST_PATTERNS = - Array( - "(?s)(?i).*(req|ctx|context)\\.(originalUrl|path|protocol|route|secure|signedCookies|stale|subdomains|xhr|app|pipe|file|files|baseUrl|fresh|hostname|ip|url|ips|method|body|param|params|query|cookies|request).*" - ) + // FIXME: Replace these with semantic fingerprints + private def JS_REQUEST_PATTERNS = + Array( + "(?s)(?i).*(req|ctx|context)\\.(originalUrl|path|protocol|route|secure|signedCookies|stale|subdomains|xhr|app|pipe|file|files|baseUrl|fresh|hostname|ip|url|ips|method|body|param|params|query|cookies|request).*" + ) - private def JS_RESPONSE_PATTERNS = - Array( - "(?s)(?i).*(res|ctx|context)\\.(append|attachment|body|cookie|download|end|format|json|jsonp|links|location|redirect|render|send|sendFile|sendStatus|set|vary).*", - "(?s)(?i).*res\\.(set|writeHead|setHeader).*", - "(?s)(?i).*(db|dao|mongo|mongoclient).*", - "(?s)(?i).*(\\s|\\.)(list|create|upload|delete|execute|command|invoke|submit|send)" - ) + private def JS_RESPONSE_PATTERNS = + Array( + "(?s)(?i).*(res|ctx|context)\\.(append|attachment|body|cookie|download|end|format|json|jsonp|links|location|redirect|render|send|sendFile|sendStatus|set|vary).*", + "(?s)(?i).*res\\.(set|writeHead|setHeader).*", + "(?s)(?i).*(db|dao|mongo|mongoclient).*", + "(?s)(?i).*(\\s|\\.)(list|create|upload|delete|execute|command|invoke|submit|send)" + ) - private def PY_REQUEST_PATTERNS = Array(".*views.py:.*") + private def PY_REQUEST_PATTERNS = Array(".*views.py:.*") - private def containsRegex(str: String) = - val reChars = "[](){}*+&|?.,\\$" - str.exists(reChars.contains(_)) + private def containsRegex(str: String) = + val reChars = "[](){}*+&|?.,\\$" + str.exists(reChars.contains(_)) - private val BOM_JSON_FILE = ".*(bom|cdx).json" + private val BOM_JSON_FILE = ".*(bom|cdx).json" - private def toPyModuleForm(str: String) = - if str.nonEmpty then - val tmpParts = str.split("\\.") - if str.count(_ == '.') > 1 then - s"${tmpParts.take(2).mkString(Pattern.quote(File.separator))}.*" - else if str.count(_ == '.') == 1 then s"${tmpParts.head}.py:.*" - else s"$str.py:.*" - else - str + private def toPyModuleForm(str: String) = + if str.nonEmpty then + val tmpParts = str.split("\\.") + if str.count(_ == '.') > 1 then + s"${tmpParts.take(2).mkString(Pattern.quote(File.separator))}.*" + else if str.count(_ == '.') == 1 then s"${tmpParts.head}.py:.*" + else s"$str.py:.*" + else + str - override def run(dstGraph: DiffGraphBuilder): Unit = - atom.configFile.name(BOM_JSON_FILE).content.foreach { cdxData => - val cdxJson = parse(cdxData).getOrElse(Json.Null) - val cursor: HCursor = cdxJson.hcursor - val components = - cursor.downField("components").focus.flatMap(_.asArray).getOrElse(Vector.empty) - val donePkgs = mutable.Map[String, Boolean]() - if language == Languages.JSSRC || language == Languages.JAVASCRIPT then - JS_REQUEST_PATTERNS.foreach(p => - atom.call.code(p).newTagNode("framework-input").store()(dstGraph) - ) - JS_RESPONSE_PATTERNS.foreach(p => - atom.call.code(p).newTagNode("framework-output").store()(dstGraph) - ) - if language == Languages.PHP then - atom.parameter.name("request.*").newTagNode("framework-input").store()(dstGraph) - atom.parameter.name("response.*").newTagNode("framework-output").store()(dstGraph) - atom.ret - .where(_.method.parameter.name("request.*")) - .newTagNode("framework-output") - .store()(dstGraph) - if language == Languages.PYTHON || language == Languages.PYTHONSRC then - PY_REQUEST_PATTERNS - .foreach(p => - atom.method.fullName(p).parameter.newTagNode("framework-input").store()( + override def run(dstGraph: DiffGraphBuilder): Unit = + atom.configFile.name(BOM_JSON_FILE).content.foreach { cdxData => + val cdxJson = parse(cdxData).getOrElse(Json.Null) + val cursor: HCursor = cdxJson.hcursor + val components = + cursor.downField("components").focus.flatMap(_.asArray).getOrElse(Vector.empty) + val donePkgs = mutable.Map[String, Boolean]() + if language == Languages.JSSRC || language == Languages.JAVASCRIPT then + JS_REQUEST_PATTERNS.foreach(p => + atom.call.code(p).newTagNode("framework-input").store()(dstGraph) + ) + JS_RESPONSE_PATTERNS.foreach(p => + atom.call.code(p).newTagNode("framework-output").store()(dstGraph) + ) + if language == Languages.PHP then + atom.parameter.name("request.*").newTagNode("framework-input").store()(dstGraph) + atom.parameter.name("response.*").newTagNode("framework-output").store()(dstGraph) + atom.ret + .where(_.method.parameter.name("request.*")) + .newTagNode("framework-output") + .store()(dstGraph) + if language == Languages.PYTHON || language == Languages.PYTHONSRC then + PY_REQUEST_PATTERNS + .foreach(p => + atom.method.fullName(p).parameter.newTagNode("framework-input").store()( + dstGraph + ) + ) + components.foreach { comp => + val PURL_TYPE = "purl" + val compPurl = comp.hcursor.downField(PURL_TYPE).as[String].getOrElse("") + val compType = comp.hcursor.downField("type").as[String].getOrElse("") + val compDescription: String = + comp.hcursor.downField("description").as[String].getOrElse("") + val descTags = keywords.filter(k => + compDescription.toLowerCase().contains(" " + k) + ).take(TAGS_COUNT) + if (language == Languages.PYTHON || language == Languages.PYTHONSRC) && compPurl.startsWith( + "pkg:pypi" + ) + then + val pkgName = compPurl.split("@").head.replace("pkg:pypi/", "") + .replace("python-", "") + .replace("-", "_"); + Set( + pkgName, + pkgName.replace("flask_", ""), + pkgName.replace("django_", ""), + pkgName.replace("py", "") + ).foreach { ns => + Set(toPyModuleForm(ns), s"$ns${Pattern.quote(File.separator)}.*").foreach { + bpkg => + if bpkg.nonEmpty && !donePkgs.contains(bpkg) then + donePkgs.put(bpkg, true) + atom.call.where( + _.methodFullName(bpkg) + ).newTagNode(compPurl).store()(dstGraph) + atom.identifier.typeFullName(bpkg).newTagNode( + compPurl + ).store()(dstGraph) + } + } + end if + val properties = comp.hcursor.downField("properties").focus.flatMap( + _.asArray + ).getOrElse(Vector.empty) + properties.foreach { ns => + val nsstr = ns.hcursor.downField("value").as[String].getOrElse("") + val nsname = ns.hcursor.downField("name").as[String].getOrElse("") + // Skip the SrcFile, ResolvedUrl, GradleProfileName, cdx: properties + if nsname != "SrcFile" && nsname != "ResolvedUrl" && nsname != "GradleProfileName" && !nsname + .startsWith( + "cdx:" + ) + then + nsstr + .split("(\n|,)") + .filterNot(_.startsWith("java.")) + .filterNot(_.startsWith("com.sun")) + .filterNot(_.contains("test")) + .filterNot(_.contains("mock")) + .foreach { (pkg: String) => + var bpkg = pkg.takeWhile(_ != '$') + if language == Languages.JAVA || language == Languages.JAVASRC then + bpkg = bpkg.split("\\.").take(PKG_NS_SIZE).mkString(".").concat( + ".*" + ) + bpkg = + bpkg.replace(File.separator, Pattern.quote(File.separator)) + if language == Languages.JSSRC || language == Languages.JAVASCRIPT + then + bpkg = s".*${bpkg}.*" + bpkg = + bpkg.replace(File.separator, Pattern.quote(File.separator)) + if language == Languages.PYTHON || language == Languages.PYTHONSRC + then bpkg = toPyModuleForm(bpkg) + if language == Languages.PHP + then + bpkg = bpkg.replaceAll("""\\""", """\\\\""") + bpkg = s"""$bpkg.*""" + if bpkg.nonEmpty && !donePkgs.contains(bpkg) then + donePkgs.put(bpkg, true) + // C/C++ + if language == Languages.NEWC || language == Languages.C + then + atom.method.fullNameExact(bpkg).callIn( + NoResolve + ).newTagNode( + compPurl + ).store()(dstGraph) + atom.method.fullNameExact(bpkg).callIn( + NoResolve + ).newTagNode( + "library-call" + ).store()(dstGraph) + atom.method.fullNameExact(bpkg).newTagNode( + compPurl + ).store()(dstGraph) + if !containsRegex(bpkg) then + atom.parameter.typeFullName(s"$bpkg.*").newTagNode( + compPurl + ).store()(dstGraph) + atom.parameter.typeFullName(s"$bpkg.*").newTagNode( + "framework-input" + ).store()(dstGraph) + atom.parameter.typeFullName(s"$bpkg.*").method.callIn( + NoResolve + ).newTagNode( + compPurl + ).store()(dstGraph) + atom.call.code(s".*\\.$bpkg.*").newTagNode( + compPurl + ).store()(dstGraph) + atom.call.code(s".*\\.$bpkg.*").newTagNode( + "library-call" + ).store()(dstGraph) + atom.call.code(s"$bpkg->.*").newTagNode( + compPurl + ).store()(dstGraph) + atom.call.code(s"$bpkg->.*").newTagNode( + "library-call" + ).store()(dstGraph) + else + atom.parameter.typeFullName( + s"${Pattern.quote(bpkg)}.*" + ).newTagNode( + compPurl + ).store()(dstGraph) + atom.parameter.typeFullName( + s"${Pattern.quote(bpkg)}.*" + ).newTagNode( + "framework-input" + ).store()(dstGraph) + atom.parameter.typeFullName( + s"${Pattern.quote(bpkg)}.*" + ).method.callIn( + NoResolve + ).newTagNode( + compPurl + ).store()(dstGraph) + end if + else if !containsRegex(bpkg) then + atom.call.typeFullNameExact(bpkg).newTagNode( + compPurl + ).store()(dstGraph) + atom.identifier.typeFullNameExact(bpkg).newTagNode( + compPurl + ).store()(dstGraph) + atom.method.parameter.typeFullNameExact(bpkg).newTagNode( + compPurl + ).store()(dstGraph) + atom.method.fullName( + s"${Pattern.quote(bpkg)}.*" + ).newTagNode(compPurl).store()(dstGraph) + else + atom.call.typeFullName(bpkg).newTagNode(compPurl).store()( dstGraph ) - ) - components.foreach { comp => - val PURL_TYPE = "purl" - val compPurl = comp.hcursor.downField(PURL_TYPE).as[String].getOrElse("") - val compType = comp.hcursor.downField("type").as[String].getOrElse("") - val compDescription: String = - comp.hcursor.downField("description").as[String].getOrElse("") - val descTags = keywords.filter(k => - compDescription.toLowerCase().contains(" " + k) - ).take(TAGS_COUNT) - if (language == Languages.PYTHON || language == Languages.PYTHONSRC) && compPurl.startsWith( - "pkg:pypi" - ) - then - val pkgName = compPurl.split("@").head.replace("pkg:pypi/", "") - .replace("python-", "") - .replace("-", "_"); - Set( - pkgName, - pkgName.replace("flask_", ""), - pkgName.replace("django_", ""), - pkgName.replace("py", "") - ).foreach { ns => - Set(toPyModuleForm(ns), s"$ns${Pattern.quote(File.separator)}.*").foreach { - bpkg => - if bpkg.nonEmpty && !donePkgs.contains(bpkg) then - donePkgs.put(bpkg, true) - atom.call.where( - _.methodFullName(bpkg) - ).newTagNode(compPurl).store()(dstGraph) - atom.identifier.typeFullName(bpkg).newTagNode( - compPurl - ).store()(dstGraph) - } - } - end if - val properties = comp.hcursor.downField("properties").focus.flatMap( - _.asArray - ).getOrElse(Vector.empty) - properties.foreach { ns => - val nsstr = ns.hcursor.downField("value").as[String].getOrElse("") - val nsname = ns.hcursor.downField("name").as[String].getOrElse("") - // Skip the SrcFile, ResolvedUrl, GradleProfileName, cdx: properties - if nsname != "SrcFile" && nsname != "ResolvedUrl" && nsname != "GradleProfileName" && !nsname.startsWith( - "cdx:" + atom.identifier.typeFullName(bpkg).newTagNode( + compPurl + ).store()(dstGraph) + atom.method.parameter.typeFullName(bpkg).newTagNode( + compPurl + ).store()(dstGraph) + atom.method.fullName(bpkg).newTagNode(compPurl).store()( + dstGraph ) - then - nsstr - .split("(\n|,)") - .filterNot(_.startsWith("java.")) - .filterNot(_.startsWith("com.sun")) - .filterNot(_.contains("test")) - .filterNot(_.contains("mock")) - .foreach { (pkg: String) => - var bpkg = pkg.takeWhile(_ != '$') - if language == Languages.JAVA || language == Languages.JAVASRC then - bpkg = bpkg.split("\\.").take(PKG_NS_SIZE).mkString(".").concat( - ".*" - ) - bpkg = - bpkg.replace(File.separator, Pattern.quote(File.separator)) - if language == Languages.JSSRC || language == Languages.JAVASCRIPT - then - bpkg = s".*${bpkg}.*" - bpkg = - bpkg.replace(File.separator, Pattern.quote(File.separator)) - if language == Languages.PYTHON || language == Languages.PYTHONSRC - then bpkg = toPyModuleForm(bpkg) - if language == Languages.PHP - then - bpkg = bpkg.replaceAll("""\\""", """\\\\""") - bpkg = s"""$bpkg.*""" - if bpkg.nonEmpty && !donePkgs.contains(bpkg) then - donePkgs.put(bpkg, true) - // C/C++ - if language == Languages.NEWC || language == Languages.C - then - atom.method.fullNameExact(bpkg).callIn( - NoResolve - ).newTagNode( - compPurl - ).store()(dstGraph) - atom.method.fullNameExact(bpkg).callIn( - NoResolve - ).newTagNode( - "library-call" - ).store()(dstGraph) - atom.method.fullNameExact(bpkg).newTagNode( - compPurl - ).store()(dstGraph) - if !containsRegex(bpkg) then - atom.parameter.typeFullName(s"$bpkg.*").newTagNode( - compPurl - ).store()(dstGraph) - atom.parameter.typeFullName(s"$bpkg.*").newTagNode( - "framework-input" - ).store()(dstGraph) - atom.parameter.typeFullName(s"$bpkg.*").method.callIn( - NoResolve - ).newTagNode( - compPurl - ).store()(dstGraph) - atom.call.code(s".*\\.$bpkg.*").newTagNode( - compPurl - ).store()(dstGraph) - atom.call.code(s".*\\.$bpkg.*").newTagNode( - "library-call" - ).store()(dstGraph) - atom.call.code(s"$bpkg->.*").newTagNode( - compPurl - ).store()(dstGraph) - atom.call.code(s"$bpkg->.*").newTagNode( - "library-call" - ).store()(dstGraph) - else - atom.parameter.typeFullName( - s"${Pattern.quote(bpkg)}.*" - ).newTagNode( - compPurl - ).store()(dstGraph) - atom.parameter.typeFullName( - s"${Pattern.quote(bpkg)}.*" - ).newTagNode( - "framework-input" - ).store()(dstGraph) - atom.parameter.typeFullName( - s"${Pattern.quote(bpkg)}.*" - ).method.callIn( - NoResolve - ).newTagNode( - compPurl - ).store()(dstGraph) - end if - else if !containsRegex(bpkg) then - atom.call.typeFullNameExact(bpkg).newTagNode( - compPurl - ).store()(dstGraph) - atom.identifier.typeFullNameExact(bpkg).newTagNode( - compPurl - ).store()(dstGraph) - atom.method.parameter.typeFullNameExact(bpkg).newTagNode( - compPurl - ).store()(dstGraph) - atom.method.fullName( - s"${Pattern.quote(bpkg)}.*" - ).newTagNode(compPurl).store()(dstGraph) - else - atom.call.typeFullName(bpkg).newTagNode(compPurl).store()( - dstGraph - ) - atom.identifier.typeFullName(bpkg).newTagNode( - compPurl - ).store()(dstGraph) - atom.method.parameter.typeFullName(bpkg).newTagNode( - compPurl - ).store()(dstGraph) - atom.method.fullName(bpkg).newTagNode(compPurl).store()( - dstGraph - ) - if language == Languages.JSSRC || language == Languages.JAVASCRIPT - then - atom.call.code(bpkg).argument.newTagNode( - compPurl - ).store()(dstGraph) - atom.identifier.code(bpkg).newTagNode(compPurl).store()( - dstGraph - ) - atom.identifier.code(bpkg).inCall.newTagNode( - compPurl - ).store()(dstGraph) - if language == Languages.PYTHON || language == Languages.PYTHONSRC - then - atom.call.where( - _.methodFullName(bpkg) - ).argument.newTagNode(compPurl).store()(dstGraph) - atom.identifier.typeFullName(bpkg).newTagNode( - compPurl - ).store()(dstGraph) - end if - if compType != "library" then - if !containsRegex(bpkg) then - atom.call.typeFullNameExact(bpkg).newTagNode( - compType - ).store()(dstGraph) - atom.call.typeFullNameExact(bpkg).receiver.newTagNode( - s"$compType-value" - ).store()(dstGraph) - atom.method.parameter.typeFullNameExact( - bpkg - ).newTagNode(compType).store()(dstGraph) - atom.method.fullName( - s"${Pattern.quote(bpkg)}.*" - ).newTagNode(compType).store()(dstGraph) - else - atom.call.typeFullName(bpkg).newTagNode( - compType - ).store()(dstGraph) - atom.call.typeFullName(bpkg).receiver.newTagNode( - s"$compType-value" - ).store()(dstGraph) - atom.method.parameter.typeFullName(bpkg).newTagNode( - compType - ).store()(dstGraph) - atom.method.fullName(bpkg).newTagNode(compType).store()( - dstGraph - ) - if language == Languages.JSSRC || language == Languages.JAVASCRIPT - then - atom.call.code(bpkg).argument.newTagNode( - compType - ).store()(dstGraph) - atom.identifier.code(bpkg).newTagNode( - compType - ).store()(dstGraph) - atom.identifier.code(bpkg).inCall.newTagNode( - compType - ).store()(dstGraph) - if language == Languages.PYTHON || language == Languages.PYTHONSRC - then - atom.call.where( - _.methodFullName(bpkg) - ).argument.newTagNode(compType).store()(dstGraph) - end if - if compType == "framework" then - def frameworkAnnotatedMethod = atom.annotation - .fullName(bpkg) - .method + if language == Languages.JSSRC || language == Languages.JAVASCRIPT + then + atom.call.code(bpkg).argument.newTagNode( + compPurl + ).store()(dstGraph) + atom.identifier.code(bpkg).newTagNode(compPurl).store()( + dstGraph + ) + atom.identifier.code(bpkg).inCall.newTagNode( + compPurl + ).store()(dstGraph) + if language == Languages.PYTHON || language == Languages.PYTHONSRC + then + atom.call.where( + _.methodFullName(bpkg) + ).argument.newTagNode(compPurl).store()(dstGraph) + atom.identifier.typeFullName(bpkg).newTagNode( + compPurl + ).store()(dstGraph) + end if + if compType != "library" then + if !containsRegex(bpkg) then + atom.call.typeFullNameExact(bpkg).newTagNode( + compType + ).store()(dstGraph) + atom.call.typeFullNameExact(bpkg).receiver.newTagNode( + s"$compType-value" + ).store()(dstGraph) + atom.method.parameter.typeFullNameExact( + bpkg + ).newTagNode(compType).store()(dstGraph) + atom.method.fullName( + s"${Pattern.quote(bpkg)}.*" + ).newTagNode(compType).store()(dstGraph) + else + atom.call.typeFullName(bpkg).newTagNode( + compType + ).store()(dstGraph) + atom.call.typeFullName(bpkg).receiver.newTagNode( + s"$compType-value" + ).store()(dstGraph) + atom.method.parameter.typeFullName(bpkg).newTagNode( + compType + ).store()(dstGraph) + atom.method.fullName(bpkg).newTagNode(compType).store()( + dstGraph + ) + if language == Languages.JSSRC || language == Languages.JAVASCRIPT + then + atom.call.code(bpkg).argument.newTagNode( + compType + ).store()(dstGraph) + atom.identifier.code(bpkg).newTagNode( + compType + ).store()(dstGraph) + atom.identifier.code(bpkg).inCall.newTagNode( + compType + ).store()(dstGraph) + if language == Languages.PYTHON || language == Languages.PYTHONSRC + then + atom.call.where( + _.methodFullName(bpkg) + ).argument.newTagNode(compType).store()(dstGraph) + end if + if compType == "framework" then + def frameworkAnnotatedMethod = atom.annotation + .fullName(bpkg) + .method - frameworkAnnotatedMethod.parameter - .newTagNode(s"$compType-input") - .store()(dstGraph) - atom.ret - .where(_.method.annotation.fullName(bpkg)) - .newTagNode(s"$compType-output") - .store()(dstGraph) - descTags.foreach { t => - atom.call.typeFullName(bpkg).newTagNode(t).store()(dstGraph) - atom.identifier.typeFullName(bpkg).newTagNode(t).store()( - dstGraph - ) - atom.method.parameter.typeFullName(bpkg).newTagNode( - t - ).store()(dstGraph) - if !containsRegex(bpkg) then - atom.method.fullName( - s"${Pattern.quote(bpkg)}.*" - ).newTagNode(t).store()(dstGraph) - else - atom.method.fullName(bpkg).newTagNode(t).store()( - dstGraph - ) - if language == Languages.PYTHON || language == Languages.PYTHONSRC - then - atom.call.where( - _.methodFullName(bpkg) - ).newTagNode(t).store()(dstGraph) - atom.identifier.typeFullName(bpkg).newTagNode( - t - ).store()(dstGraph) - } - end if - } + frameworkAnnotatedMethod.parameter + .newTagNode(s"$compType-input") + .store()(dstGraph) + atom.ret + .where(_.method.annotation.fullName(bpkg)) + .newTagNode(s"$compType-output") + .store()(dstGraph) + descTags.foreach { t => + atom.call.typeFullName(bpkg).newTagNode(t).store()(dstGraph) + atom.identifier.typeFullName(bpkg).newTagNode(t).store()( + dstGraph + ) + atom.method.parameter.typeFullName(bpkg).newTagNode( + t + ).store()(dstGraph) + if !containsRegex(bpkg) then + atom.method.fullName( + s"${Pattern.quote(bpkg)}.*" + ).newTagNode(t).store()(dstGraph) + else + atom.method.fullName(bpkg).newTagNode(t).store()( + dstGraph + ) + if language == Languages.PYTHON || language == Languages.PYTHONSRC + then + atom.call.where( + _.methodFullName(bpkg) + ).newTagNode(t).store()(dstGraph) + atom.identifier.typeFullName(bpkg).newTagNode( + t + ).store()(dstGraph) + } end if - } - } + } + end if + } } - end run + } + end run end CdxPass diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/taggers/ChennaiTagsPass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/taggers/ChennaiTagsPass.scala index ac53a00c..dbf9baf4 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/taggers/ChennaiTagsPass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/taggers/ChennaiTagsPass.scala @@ -14,178 +14,179 @@ import java.util.regex.Pattern */ class ChennaiTagsPass(atom: Cpg) extends CpgPass(atom): - val language: String = atom.metaData.language.head - private val FRAMEWORK_ROUTE = "framework-route" - private val FRAMEWORK_INPUT = "framework-input" - private val FRAMEWORK_OUTPUT = "framework-output" - private val EscapedFileSeparator = Pattern.quote(java.io.File.separator) + val language: String = atom.metaData.language.head + private val FRAMEWORK_ROUTE = "framework-route" + private val FRAMEWORK_INPUT = "framework-input" + private val FRAMEWORK_OUTPUT = "framework-output" + private val EscapedFileSeparator = Pattern.quote(java.io.File.separator) - private val PYTHON_ROUTES_CALL_REGEXES = - Array( - s"django$EscapedFileSeparator(conf$EscapedFileSeparator)?urls.py:.(path|re_path|url).*".r, - ".*(route|web\\.|add_resource).*".r - ) + private val PYTHON_ROUTES_CALL_REGEXES = + Array( + s"django$EscapedFileSeparator(conf$EscapedFileSeparator)?urls.py:.(path|re_path|url).*" + .r, + ".*(route|web\\.|add_resource).*".r + ) - private def C_ROUTES_CALL_REGEXES = Array( - "Routes::(Post|Get|Delete|Head|Options|Put).*", - "API_CALL", - "API_CALL_ASYNC", - "ENDPOINT", - "ENDPOINT_ASYNC", - "ENDPOINT_INTERCEPTOR", - "ENDPOINT_INTERCEPTOR_ASYNC", - "registerHandler", - "PATH_ADD", - "ADD_METHOD_TO", - "ADD_METHOD_VIA_REGEX", - "WS_PATH_ADD", - "svr\\.(Post|Get|Delete|Head|Options|Put)" - ) - private val PYTHON_ROUTES_DECORATORS_REGEXES = Array( - ".*(route|endpoint|_request|require_http_methods|require_GET|require_POST|require_safe|_required|api\\.doc|api\\.response|api\\.errorhandler)\\(.*", - ".*def\\s(get|post|put)\\(.*" - ) - private val PHP_ROUTES_METHODS_REGEXES = Array( - ".*(router|routes|r|app|map)->(addRoute|add|before|mount|get|post|put|delete|head|option).*", - ".*(Router)::(scope|connect|get|post|put|delete|head|option).*" - ) - private val HTTP_METHODS_REGEX = ".*(request|session)\\.(args|get|post|put|form).*" + private def C_ROUTES_CALL_REGEXES = Array( + "Routes::(Post|Get|Delete|Head|Options|Put).*", + "API_CALL", + "API_CALL_ASYNC", + "ENDPOINT", + "ENDPOINT_ASYNC", + "ENDPOINT_INTERCEPTOR", + "ENDPOINT_INTERCEPTOR_ASYNC", + "registerHandler", + "PATH_ADD", + "ADD_METHOD_TO", + "ADD_METHOD_VIA_REGEX", + "WS_PATH_ADD", + "svr\\.(Post|Get|Delete|Head|Options|Put)" + ) + private val PYTHON_ROUTES_DECORATORS_REGEXES = Array( + ".*(route|endpoint|_request|require_http_methods|require_GET|require_POST|require_safe|_required|api\\.doc|api\\.response|api\\.errorhandler)\\(.*", + ".*def\\s(get|post|put)\\(.*" + ) + private val PHP_ROUTES_METHODS_REGEXES = Array( + ".*(router|routes|r|app|map)->(addRoute|add|before|mount|get|post|put|delete|head|option).*", + ".*(Router)::(scope|connect|get|post|put|delete|head|option).*" + ) + private val HTTP_METHODS_REGEX = ".*(request|session)\\.(args|get|post|put|form).*" - private def containsRegex(str: String) = - val reChars = "[](){}*+&|?.,\\$" - str.exists(reChars.contains(_)) + private def containsRegex(str: String) = + val reChars = "[](){}*+&|?.,\\$" + str.exists(reChars.contains(_)) - private def tagCRoutes(dstGraph: DiffGraphBuilder): Unit = - C_ROUTES_CALL_REGEXES.foreach { r => - atom.method.fullName(r).parameter.newTagNode(FRAMEWORK_INPUT).store()( + private def tagCRoutes(dstGraph: DiffGraphBuilder): Unit = + C_ROUTES_CALL_REGEXES.foreach { r => + atom.method.fullName(r).parameter.newTagNode(FRAMEWORK_INPUT).store()( + dstGraph + ) + atom.call + .where(_.methodFullName(r)) + .argument + .isLiteral + .newTagNode(FRAMEWORK_ROUTE) + .store()(dstGraph) + } + private def tagPythonRoutes(dstGraph: DiffGraphBuilder): Unit = + PYTHON_ROUTES_CALL_REGEXES.foreach { r => + atom.call + .where(_.methodFullName(r.toString())) + .argument + .isLiteral + .newTagNode(FRAMEWORK_ROUTE) + .store()(dstGraph) + } + PYTHON_ROUTES_DECORATORS_REGEXES.foreach { r => + def decoratedMethods = atom.methodRef + .where(_.inCall.code(r).argument) + ._refOut + .collectAll[Method] + decoratedMethods.call.assignment + .code(HTTP_METHODS_REGEX) + .argument + .isIdentifier + .newTagNode(FRAMEWORK_INPUT) + .store()(dstGraph) + decoratedMethods + .newTagNode(FRAMEWORK_INPUT) + .store()(dstGraph) + decoratedMethods.parameter + .newTagNode(FRAMEWORK_INPUT) + .store()(dstGraph) + } + atom.file.name(".*views.py.*").method.parameter.name("request").method.newTagNode( + FRAMEWORK_INPUT + ).store()(dstGraph) + atom.file.name(".*controllers.*.py.*").method.name( + "get|post|put|delete|head|option" + ).parameter.filterNot(_.name == "self").newTagNode( + FRAMEWORK_INPUT + ).store()(dstGraph) + atom.file.name(".*controllers.*.py.*").method.name( + "get|post|put|delete|head|option" + ).methodReturn.newTagNode( + FRAMEWORK_OUTPUT + ).store()(dstGraph) + end tagPythonRoutes + private def tagPhpRoutes(dstGraph: DiffGraphBuilder): Unit = + PHP_ROUTES_METHODS_REGEXES.foreach { r => + atom.method.fullName(r).parameter.newTagNode(FRAMEWORK_INPUT).store()( + dstGraph + ) + atom.call.where(_.methodFullName(r)).argument.isLiteral.newTagNode( + FRAMEWORK_ROUTE + ).store()(dstGraph) + } + end tagPhpRoutes + override def run(dstGraph: DiffGraphBuilder): Unit = + if language == Languages.PYTHON || language == Languages.PYTHONSRC then + tagPythonRoutes(dstGraph) + if language == Languages.NEWC || language == Languages.C then + tagCRoutes(dstGraph) + if language == Languages.PHP then tagPhpRoutes(dstGraph) + atom.configFile("chennai.json").content.foreach { cdxData => + val ctagsJson = parse(cdxData).getOrElse(Json.Null) + val cursor: HCursor = ctagsJson.hcursor + val tags = cursor.downField("tags").focus.flatMap(_.asArray).getOrElse(Vector.empty) + tags.foreach { comp => + val tagName = comp.hcursor.downField("name").as[String].getOrElse("") + val tagParams = comp.hcursor.downField("parameters").focus.flatMap( + _.asArray + ).getOrElse(Vector.empty) + val tagMethods = comp.hcursor.downField("methods").focus.flatMap( + _.asArray + ).getOrElse(Vector.empty) + val tagTypes = + comp.hcursor.downField("types").focus.flatMap(_.asArray).getOrElse(Vector.empty) + val tagFiles = + comp.hcursor.downField("files").focus.flatMap(_.asArray).getOrElse(Vector.empty) + tagParams.foreach { paramName => + val pn = paramName.asString.getOrElse("") + if pn.nonEmpty then + atom.method.parameter.typeFullNameExact(pn).newTagNode(tagName).store()( dstGraph ) - atom.call - .where(_.methodFullName(r)) - .argument - .isLiteral - .newTagNode(FRAMEWORK_ROUTE) - .store()(dstGraph) + if !containsRegex(pn) then + atom.method.parameter.typeFullName( + s".*${Pattern.quote(pn)}.*" + ).newTagNode(tagName).store()(dstGraph) } - private def tagPythonRoutes(dstGraph: DiffGraphBuilder): Unit = - PYTHON_ROUTES_CALL_REGEXES.foreach { r => - atom.call - .where(_.methodFullName(r.toString())) - .argument - .isLiteral - .newTagNode(FRAMEWORK_ROUTE) - .store()(dstGraph) + tagMethods.foreach { methodName => + val mn = methodName.asString.getOrElse("") + if mn.nonEmpty then + atom.method.fullNameExact(mn).newTagNode(tagName).store()(dstGraph) + if !containsRegex(mn) then + atom.method.fullName(s".*${Pattern.quote(mn)}.*").newTagNode( + tagName + ).store()(dstGraph) } - PYTHON_ROUTES_DECORATORS_REGEXES.foreach { r => - def decoratedMethods = atom.methodRef - .where(_.inCall.code(r).argument) - ._refOut - .collectAll[Method] - decoratedMethods.call.assignment - .code(HTTP_METHODS_REGEX) - .argument - .isIdentifier - .newTagNode(FRAMEWORK_INPUT) - .store()(dstGraph) - decoratedMethods - .newTagNode(FRAMEWORK_INPUT) - .store()(dstGraph) - decoratedMethods.parameter - .newTagNode(FRAMEWORK_INPUT) - .store()(dstGraph) - } - atom.file.name(".*views.py.*").method.parameter.name("request").method.newTagNode( - FRAMEWORK_INPUT - ).store()(dstGraph) - atom.file.name(".*controllers.*.py.*").method.name( - "get|post|put|delete|head|option" - ).parameter.filterNot(_.name == "self").newTagNode( - FRAMEWORK_INPUT - ).store()(dstGraph) - atom.file.name(".*controllers.*.py.*").method.name( - "get|post|put|delete|head|option" - ).methodReturn.newTagNode( - FRAMEWORK_OUTPUT - ).store()(dstGraph) - end tagPythonRoutes - private def tagPhpRoutes(dstGraph: DiffGraphBuilder): Unit = - PHP_ROUTES_METHODS_REGEXES.foreach { r => - atom.method.fullName(r).parameter.newTagNode(FRAMEWORK_INPUT).store()( + tagTypes.foreach { typeName => + val tn = typeName.asString.getOrElse("") + if tn.nonEmpty then + atom.method.parameter.typeFullNameExact(tn).newTagNode(tagName).store()( dstGraph ) - atom.call.where(_.methodFullName(r)).argument.isLiteral.newTagNode( - FRAMEWORK_ROUTE - ).store()(dstGraph) + if !containsRegex(tn) then + atom.method.parameter.typeFullName( + s".*${Pattern.quote(tn)}.*" + ).newTagNode(tagName).store()(dstGraph) + atom.call.typeFullNameExact(tn).newTagNode(tagName).store()(dstGraph) + if !tn.contains("[") && !tn.contains("*") then + atom.call.typeFullName(s".*${Pattern.quote(tn)}.*").newTagNode( + tagName + ).store()(dstGraph) } - end tagPhpRoutes - override def run(dstGraph: DiffGraphBuilder): Unit = - if language == Languages.PYTHON || language == Languages.PYTHONSRC then - tagPythonRoutes(dstGraph) - if language == Languages.NEWC || language == Languages.C then - tagCRoutes(dstGraph) - if language == Languages.PHP then tagPhpRoutes(dstGraph) - atom.configFile("chennai.json").content.foreach { cdxData => - val ctagsJson = parse(cdxData).getOrElse(Json.Null) - val cursor: HCursor = ctagsJson.hcursor - val tags = cursor.downField("tags").focus.flatMap(_.asArray).getOrElse(Vector.empty) - tags.foreach { comp => - val tagName = comp.hcursor.downField("name").as[String].getOrElse("") - val tagParams = comp.hcursor.downField("parameters").focus.flatMap( - _.asArray - ).getOrElse(Vector.empty) - val tagMethods = comp.hcursor.downField("methods").focus.flatMap( - _.asArray - ).getOrElse(Vector.empty) - val tagTypes = - comp.hcursor.downField("types").focus.flatMap(_.asArray).getOrElse(Vector.empty) - val tagFiles = - comp.hcursor.downField("files").focus.flatMap(_.asArray).getOrElse(Vector.empty) - tagParams.foreach { paramName => - val pn = paramName.asString.getOrElse("") - if pn.nonEmpty then - atom.method.parameter.typeFullNameExact(pn).newTagNode(tagName).store()( - dstGraph - ) - if !containsRegex(pn) then - atom.method.parameter.typeFullName( - s".*${Pattern.quote(pn)}.*" - ).newTagNode(tagName).store()(dstGraph) - } - tagMethods.foreach { methodName => - val mn = methodName.asString.getOrElse("") - if mn.nonEmpty then - atom.method.fullNameExact(mn).newTagNode(tagName).store()(dstGraph) - if !containsRegex(mn) then - atom.method.fullName(s".*${Pattern.quote(mn)}.*").newTagNode( - tagName - ).store()(dstGraph) - } - tagTypes.foreach { typeName => - val tn = typeName.asString.getOrElse("") - if tn.nonEmpty then - atom.method.parameter.typeFullNameExact(tn).newTagNode(tagName).store()( - dstGraph - ) - if !containsRegex(tn) then - atom.method.parameter.typeFullName( - s".*${Pattern.quote(tn)}.*" - ).newTagNode(tagName).store()(dstGraph) - atom.call.typeFullNameExact(tn).newTagNode(tagName).store()(dstGraph) - if !tn.contains("[") && !tn.contains("*") then - atom.call.typeFullName(s".*${Pattern.quote(tn)}.*").newTagNode( - tagName - ).store()(dstGraph) - } - tagFiles.foreach { fileName => - val fn = fileName.asString.getOrElse("") - if fn.nonEmpty then - atom.file.nameExact(fn).newTagNode(tagName).store()(dstGraph) - if !containsRegex(fn) then - atom.file.name(s".*${Pattern.quote(fn)}.*").newTagNode(tagName).store()( - dstGraph - ) - } - } + tagFiles.foreach { fileName => + val fn = fileName.asString.getOrElse("") + if fn.nonEmpty then + atom.file.nameExact(fn).newTagNode(tagName).store()(dstGraph) + if !containsRegex(fn) then + atom.file.name(s".*${Pattern.quote(fn)}.*").newTagNode(tagName).store()( + dstGraph + ) } - end run + } + } + end run end ChennaiTagsPass diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/taggers/EasyTagsPass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/taggers/EasyTagsPass.scala index 067e8ac9..16277d90 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/taggers/EasyTagsPass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/taggers/EasyTagsPass.scala @@ -9,96 +9,96 @@ import io.shiftleft.semanticcpg.language.* */ class EasyTagsPass(atom: Cpg) extends CpgPass(atom): - val language: String = atom.metaData.language.head + val language: String = atom.metaData.language.head - override def run(dstGraph: DiffGraphBuilder): Unit = - atom.method.internal.name(".*(valid|check).*").newTagNode("validation").store()(dstGraph) - atom.method.internal.name("is[A-Z].*").newTagNode("validation").store()(dstGraph) - atom.method.internal.name(".*(encode|escape|sanit).*").newTagNode("sanitization").store()( - dstGraph - ) - atom.method.internal.name(".*(login|authenti).*").newTagNode("authentication").store()( - dstGraph - ) - atom.method.internal.name(".*(authori).*").newTagNode("authorization").store()(dstGraph) - if language == Languages.JSSRC || language == Languages.JAVASCRIPT then - // Tag cli source - atom.method.internal.fullName("(index|app).(js|jsx|ts|tsx)::program").newTagNode( - "cli-source" - ).store()( - dstGraph - ) - // Tag exported methods - atom.call.where(_.methodFullName(Operators.assignment)).code( - "(module\\.)?exports.*" - ).argument.isCall.methodFullName.filterNot(_.startsWith("<")).foreach { m => - atom.method.nameExact(m).newTagNode("exported").store()(dstGraph) - } - else if language == Languages.PYTHON || language == Languages.PYTHONSRC then - atom.method.internal.name("is_[a-z].*").newTagNode("validation").store()(dstGraph) - atom.call.methodFullName(Operators.equals).codeExact( - "__name__ == '__main__'" - ).controls.isCall.filterNot(_.name.startsWith(" + atom.method.nameExact(m).newTagNode("exported").store()(dstGraph) + } + else if language == Languages.PYTHON || language == Languages.PYTHONSRC then + atom.method.internal.name("is_[a-z].*").newTagNode("validation").store()(dstGraph) + atom.call.methodFullName(Operators.equals).codeExact( + "__name__ == '__main__'" + ).controls.isCall.filterNot(_.name.startsWith(" atom.method.nameExact(a.code).parameter.newTagNode("framework-input").store()( dstGraph @@ -109,68 +109,68 @@ class EasyTagsPass(atom: Cpg) extends CpgPass(atom): dstGraph ) } - */ - atom.method.name("wp_cron").newTagNode("cron").store()(dstGraph) - atom.method.name("wp_mail").newTagNode("mail").store()(dstGraph) - atom.method.name("wp_signon").newTagNode("authentication").store()(dstGraph) - atom.method.name("wp_remote_.*").newTagNode("http").store()(dstGraph) - end if - if language == Languages.JAVA || language == Languages.JAVASRC then - atom.identifier.typeFullName("java.security.*").newTagNode("crypto").store()(dstGraph) - atom.identifier.typeFullName("org.bouncycastle.*").newTagNode("crypto").store()( - dstGraph - ) - atom.identifier.typeFullName("org.apache.xml.security.*").newTagNode("crypto").store()( - dstGraph - ) - atom.identifier.typeFullName("javax.(security|crypto).*").newTagNode("crypto").store()( - dstGraph - ) - atom.call.methodFullName("java.security.*").newTagNode("crypto").store()(dstGraph) - atom.call.methodFullName("org.bouncycastle.*").newTagNode("crypto").store()(dstGraph) - atom.call.methodFullName("org.apache.xml.security.*").newTagNode("crypto").store()( - dstGraph - ) - atom.call.methodFullName("javax.(security|crypto).*").newTagNode("crypto").store()( - dstGraph - ) - atom.call.methodFullName("java.security.*doFinal.*").newTagNode( - "crypto-generate" - ).store()(dstGraph) - atom.call.methodFullName("org.bouncycastle.*(doFinal|generate|build).*").newTagNode( - "crypto-generate" - ).store()(dstGraph) - atom.call.methodFullName( - "org.apache.xml.security.*(doFinal|create|decrypt|encrypt|load|martial).*" - ).newTagNode( - "crypto-generate" - ).store()(dstGraph) - atom.call.methodFullName("javax.(security|crypto).*doFinal.*").newTagNode( - "crypto-generate" - ).store()( - dstGraph - ) - atom.literal.code( - "\"(DSA|ECDSA|GOST-3410|ECGOST-3410|MD5|SHA1|SHA224|SHA384|SHA512|ECDH|PKCS12|DES|DESEDE|IDEA|RC2|RC5|MD2|MD4|MD5|RIPEMD128|RIPEMD160|RIPEMD256|AES|Blowfish|CAST5|CAST6|DES|DESEDE|GOST-28147|IDEA|RC6|Rijndael|Serpent|Skipjack|Twofish|OpenPGPCFB|PKCS7Padding|ISO10126-2Padding|ISO7816-4Padding|TBCPadding|X9.23Padding|ZeroBytePadding|PBEWithMD5AndDES|PBEWithSHA1AndDES|PBEWithSHA1AndRC2|PBEWithMD5AndRC2|PBEWithSHA1AndIDEA|PBEWithSHA1And3-KeyTripleDES|PBEWithSHA1And2-KeyTripleDES|PBEWithSHA1And40BitRC2|PBEWithSHA1And40BitRC4|PBEWithSHA1And128BitRC2|PBEWithSHA1And128BitRC4|PBEWithSHA1AndTwofish|ChaCha20|ChaCha20-Poly1305|DESede|DiffieHellman|OAEP|PBEWithMD5AndDES|PBEWithHmacSHA256AndAES|RSASSA-PSS|X25519|X448|XDH|X.509|PKCS7|PkiPath|PKIX|AESWrap|ARCFOUR|ISO10126Padding|OAEPWithMD5AndMGF1Padding|OAEPWithSHA-512AndMGF1Padding|PKCS1Padding|PKCS5Padding|SSL3Padding|ECMQV|HmacMD5|HmacSHA1|HmacSHA224|HmacSHA256|HmacSHA384|HmacSHA512|HmacSHA3-224|HmacSHA3-256|HmacSHA3-384|HmacSHA3-512|SHA3-224|SHA3-256|SHA3-384|SHA3-512|SHA-1|SHA-224|SHA-256|SHA-384|SHA-512|CRAM-MD5|DIGEST-MD5|GSSAPI|NTLM|PBKDF2WithHmacSHA256|NativePRNG|NativePRNGBlocking|NativePRNGNonBlocking|SHA1PRNG|Windows-PRNG|NONEwithRSA|MD2withRSA|MD5withRSA|SHA1withRSA|SHA224withRSA|SHA256withRSA|SHA384withRSA|SHA512withRSA|SHA3-224withRSA|SHA3-256withRSA|SHA3-384withRSA|SHA3-512withRSA|NONEwithECDSAinP1363Format|SHA1withECDSAinP1363Format|SHA224withECDSAinP1363Format|SHA256withECDSAinP1363Format|SHA384withECDSAinP1363Format|SHA512withECDSAinP1363Format|SSLv2|SSLv3|TLSv1|DTLS|SSL_|TLS_).*" - ).newTagNode("crypto-algorithm").store()(dstGraph) - end if - if language == Languages.PYTHON || language == Languages.PYTHONSRC then - val known_crypto_libs = "(cryptography|Crypto|ecdsa|nacl).*" - atom.identifier.typeFullName(known_crypto_libs).newTagNode( - "crypto" - ).store()(dstGraph) - atom.call.methodFullName(known_crypto_libs).newTagNode( - "crypto" - ).store()(dstGraph) - atom.call.methodFullName( - s"${known_crypto_libs}(generate|encrypt|decrypt|derive|sign|public_bytes|private_bytes|exchange|new|update|export_key|import_key|from_string|from_pem|to_pem).*" - ).newTagNode( - "crypto-generate" - ).store()(dstGraph) - atom.call.name("[A-Z0-9]+").methodFullName( - s"${known_crypto_libs}(primitives|serialization).*" - ).argument.inCall.newTagNode( - "crypto-algorithm" - ).store()(dstGraph) - end run + */ + atom.method.name("wp_cron").newTagNode("cron").store()(dstGraph) + atom.method.name("wp_mail").newTagNode("mail").store()(dstGraph) + atom.method.name("wp_signon").newTagNode("authentication").store()(dstGraph) + atom.method.name("wp_remote_.*").newTagNode("http").store()(dstGraph) + end if + if language == Languages.JAVA || language == Languages.JAVASRC then + atom.identifier.typeFullName("java.security.*").newTagNode("crypto").store()(dstGraph) + atom.identifier.typeFullName("org.bouncycastle.*").newTagNode("crypto").store()( + dstGraph + ) + atom.identifier.typeFullName("org.apache.xml.security.*").newTagNode("crypto").store()( + dstGraph + ) + atom.identifier.typeFullName("javax.(security|crypto).*").newTagNode("crypto").store()( + dstGraph + ) + atom.call.methodFullName("java.security.*").newTagNode("crypto").store()(dstGraph) + atom.call.methodFullName("org.bouncycastle.*").newTagNode("crypto").store()(dstGraph) + atom.call.methodFullName("org.apache.xml.security.*").newTagNode("crypto").store()( + dstGraph + ) + atom.call.methodFullName("javax.(security|crypto).*").newTagNode("crypto").store()( + dstGraph + ) + atom.call.methodFullName("java.security.*doFinal.*").newTagNode( + "crypto-generate" + ).store()(dstGraph) + atom.call.methodFullName("org.bouncycastle.*(doFinal|generate|build).*").newTagNode( + "crypto-generate" + ).store()(dstGraph) + atom.call.methodFullName( + "org.apache.xml.security.*(doFinal|create|decrypt|encrypt|load|martial).*" + ).newTagNode( + "crypto-generate" + ).store()(dstGraph) + atom.call.methodFullName("javax.(security|crypto).*doFinal.*").newTagNode( + "crypto-generate" + ).store()( + dstGraph + ) + atom.literal.code( + "\"(DSA|ECDSA|GOST-3410|ECGOST-3410|MD5|SHA1|SHA224|SHA384|SHA512|ECDH|PKCS12|DES|DESEDE|IDEA|RC2|RC5|MD2|MD4|MD5|RIPEMD128|RIPEMD160|RIPEMD256|AES|Blowfish|CAST5|CAST6|DES|DESEDE|GOST-28147|IDEA|RC6|Rijndael|Serpent|Skipjack|Twofish|OpenPGPCFB|PKCS7Padding|ISO10126-2Padding|ISO7816-4Padding|TBCPadding|X9.23Padding|ZeroBytePadding|PBEWithMD5AndDES|PBEWithSHA1AndDES|PBEWithSHA1AndRC2|PBEWithMD5AndRC2|PBEWithSHA1AndIDEA|PBEWithSHA1And3-KeyTripleDES|PBEWithSHA1And2-KeyTripleDES|PBEWithSHA1And40BitRC2|PBEWithSHA1And40BitRC4|PBEWithSHA1And128BitRC2|PBEWithSHA1And128BitRC4|PBEWithSHA1AndTwofish|ChaCha20|ChaCha20-Poly1305|DESede|DiffieHellman|OAEP|PBEWithMD5AndDES|PBEWithHmacSHA256AndAES|RSASSA-PSS|X25519|X448|XDH|X.509|PKCS7|PkiPath|PKIX|AESWrap|ARCFOUR|ISO10126Padding|OAEPWithMD5AndMGF1Padding|OAEPWithSHA-512AndMGF1Padding|PKCS1Padding|PKCS5Padding|SSL3Padding|ECMQV|HmacMD5|HmacSHA1|HmacSHA224|HmacSHA256|HmacSHA384|HmacSHA512|HmacSHA3-224|HmacSHA3-256|HmacSHA3-384|HmacSHA3-512|SHA3-224|SHA3-256|SHA3-384|SHA3-512|SHA-1|SHA-224|SHA-256|SHA-384|SHA-512|CRAM-MD5|DIGEST-MD5|GSSAPI|NTLM|PBKDF2WithHmacSHA256|NativePRNG|NativePRNGBlocking|NativePRNGNonBlocking|SHA1PRNG|Windows-PRNG|NONEwithRSA|MD2withRSA|MD5withRSA|SHA1withRSA|SHA224withRSA|SHA256withRSA|SHA384withRSA|SHA512withRSA|SHA3-224withRSA|SHA3-256withRSA|SHA3-384withRSA|SHA3-512withRSA|NONEwithECDSAinP1363Format|SHA1withECDSAinP1363Format|SHA224withECDSAinP1363Format|SHA256withECDSAinP1363Format|SHA384withECDSAinP1363Format|SHA512withECDSAinP1363Format|SSLv2|SSLv3|TLSv1|DTLS|SSL_|TLS_).*" + ).newTagNode("crypto-algorithm").store()(dstGraph) + end if + if language == Languages.PYTHON || language == Languages.PYTHONSRC then + val known_crypto_libs = "(cryptography|Crypto|ecdsa|nacl).*" + atom.identifier.typeFullName(known_crypto_libs).newTagNode( + "crypto" + ).store()(dstGraph) + atom.call.methodFullName(known_crypto_libs).newTagNode( + "crypto" + ).store()(dstGraph) + atom.call.methodFullName( + s"${known_crypto_libs}(generate|encrypt|decrypt|derive|sign|public_bytes|private_bytes|exchange|new|update|export_key|import_key|from_string|from_pem|to_pem).*" + ).newTagNode( + "crypto-generate" + ).store()(dstGraph) + atom.call.name("[A-Z0-9]+").methodFullName( + s"${known_crypto_libs}(primitives|serialization).*" + ).argument.inCall.newTagNode( + "crypto-algorithm" + ).store()(dstGraph) + end run end EasyTagsPass diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/typerelations/AliasLinkerPass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/typerelations/AliasLinkerPass.scala index 6088fb4d..75267f87 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/typerelations/AliasLinkerPass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/typerelations/AliasLinkerPass.scala @@ -8,16 +8,16 @@ import io.appthreat.x2cpg.utils.LinkingUtil class AliasLinkerPass(cpg: Cpg) extends CpgPass(cpg) with LinkingUtil: - override def run(dstGraph: DiffGraphBuilder): Unit = - // Create ALIAS_OF edges from TYPE_DECL nodes to TYPE - linkToMultiple( - cpg, - srcLabels = List(NodeTypes.TYPE_DECL), - dstNodeLabel = NodeTypes.TYPE, - edgeType = EdgeTypes.ALIAS_OF, - dstNodeMap = typeFullNameToNode(cpg, _), - getDstFullNames = (srcNode: TypeDecl) => - srcNode.aliasTypeFullName, - dstFullNameKey = PropertyNames.ALIAS_TYPE_FULL_NAME, - dstGraph - ) + override def run(dstGraph: DiffGraphBuilder): Unit = + // Create ALIAS_OF edges from TYPE_DECL nodes to TYPE + linkToMultiple( + cpg, + srcLabels = List(NodeTypes.TYPE_DECL), + dstNodeLabel = NodeTypes.TYPE, + edgeType = EdgeTypes.ALIAS_OF, + dstNodeMap = typeFullNameToNode(cpg, _), + getDstFullNames = (srcNode: TypeDecl) => + srcNode.aliasTypeFullName, + dstFullNameKey = PropertyNames.ALIAS_TYPE_FULL_NAME, + dstGraph + ) diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/typerelations/TypeHierarchyPass.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/typerelations/TypeHierarchyPass.scala index a9a2ced5..1b7483c4 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/typerelations/TypeHierarchyPass.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/passes/typerelations/TypeHierarchyPass.scala @@ -10,19 +10,19 @@ import io.appthreat.x2cpg.utils.LinkingUtil */ class TypeHierarchyPass(cpg: Cpg) extends CpgPass(cpg) with LinkingUtil: - override def run(dstGraph: DiffGraphBuilder): Unit = - linkToMultiple( - cpg, - srcLabels = List(NodeTypes.TYPE_DECL), - dstNodeLabel = NodeTypes.TYPE, - edgeType = EdgeTypes.INHERITS_FROM, - dstNodeMap = typeFullNameToNode(cpg, _), - getDstFullNames = (srcNode: TypeDecl) => - if srcNode.inheritsFromTypeFullName != null then - srcNode.inheritsFromTypeFullName - else - Seq() - , - dstFullNameKey = PropertyNames.INHERITS_FROM_TYPE_FULL_NAME, - dstGraph - ) + override def run(dstGraph: DiffGraphBuilder): Unit = + linkToMultiple( + cpg, + srcLabels = List(NodeTypes.TYPE_DECL), + dstNodeLabel = NodeTypes.TYPE, + edgeType = EdgeTypes.INHERITS_FROM, + dstNodeMap = typeFullNameToNode(cpg, _), + getDstFullNames = (srcNode: TypeDecl) => + if srcNode.inheritsFromTypeFullName != null then + srcNode.inheritsFromTypeFullName + else + Seq() + , + dstFullNameKey = PropertyNames.INHERITS_FROM_TYPE_FULL_NAME, + dstGraph + ) diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/AstPropertiesUtil.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/AstPropertiesUtil.scala index 06f5e268..ce688226 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/AstPropertiesUtil.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/AstPropertiesUtil.scala @@ -5,26 +5,26 @@ import io.shiftleft.codepropertygraph.generated.PropertyNames object AstPropertiesUtil: - implicit class RootProperties(val ast: Ast) extends AnyVal: + implicit class RootProperties(val ast: Ast) extends AnyVal: - private def rootProperty(propertyName: String): Option[String] = - ast.root.flatMap(_.properties.get(propertyName).map(_.toString)) + private def rootProperty(propertyName: String): Option[String] = + ast.root.flatMap(_.properties.get(propertyName).map(_.toString)) - def rootType: Option[String] = rootProperty(PropertyNames.TYPE_FULL_NAME) + def rootType: Option[String] = rootProperty(PropertyNames.TYPE_FULL_NAME) - def rootCode: Option[String] = rootProperty(PropertyNames.CODE) + def rootCode: Option[String] = rootProperty(PropertyNames.CODE) - def rootName: Option[String] = rootProperty(PropertyNames.NAME) + def rootName: Option[String] = rootProperty(PropertyNames.NAME) - def rootCodeOrEmpty: String = rootCode.getOrElse("") + def rootCodeOrEmpty: String = rootCode.getOrElse("") - implicit class RootPropertiesOnSeq(val asts: Seq[Ast]) extends AnyVal: + implicit class RootPropertiesOnSeq(val asts: Seq[Ast]) extends AnyVal: - def rootType: Option[String] = asts.headOption.flatMap(_.rootType) + def rootType: Option[String] = asts.headOption.flatMap(_.rootType) - def rootCode: Option[String] = asts.headOption.flatMap(_.rootCode) + def rootCode: Option[String] = asts.headOption.flatMap(_.rootCode) - def rootName: Option[String] = asts.headOption.flatMap(_.rootName) + def rootName: Option[String] = asts.headOption.flatMap(_.rootName) - def rootCodeOrEmpty: String = asts.rootCode.getOrElse("") + def rootCodeOrEmpty: String = asts.rootCode.getOrElse("") end AstPropertiesUtil diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/Environment.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/Environment.scala index e9b9c638..eeaf456d 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/Environment.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/Environment.scala @@ -6,34 +6,34 @@ import java.nio.file.Paths object Environment: - object OperatingSystemType extends Enumeration: - type OperatingSystemType = Value + object OperatingSystemType extends Enumeration: + type OperatingSystemType = Value - val Windows, Linux, Mac, Unknown = Value + val Windows, Linux, Mac, Unknown = Value - object ArchitectureType extends Enumeration: - type ArchitectureType = Value + object ArchitectureType extends Enumeration: + type ArchitectureType = Value - val X86, ARM = Value + val X86, ARM = Value - lazy val operatingSystem: OperatingSystemType.OperatingSystemType = - if scala.util.Properties.isMac then OperatingSystemType.Mac - else if scala.util.Properties.isLinux then OperatingSystemType.Linux - else if scala.util.Properties.isWin then OperatingSystemType.Windows - else OperatingSystemType.Unknown + lazy val operatingSystem: OperatingSystemType.OperatingSystemType = + if scala.util.Properties.isMac then OperatingSystemType.Mac + else if scala.util.Properties.isLinux then OperatingSystemType.Linux + else if scala.util.Properties.isWin then OperatingSystemType.Windows + else OperatingSystemType.Unknown - lazy val architecture: ArchitectureType.ArchitectureType = - if scala.util.Properties.propOrNone("os.arch").contains("aarch64") then ArchitectureType.ARM - // We do not distinguish between x86 and x64. E.g, a 64 bit Windows will always lie about - // this and will report x86 anyway for backwards compatibility with 32 bit software. - else ArchitectureType.X86 + lazy val architecture: ArchitectureType.ArchitectureType = + if scala.util.Properties.propOrNone("os.arch").contains("aarch64") then ArchitectureType.ARM + // We do not distinguish between x86 and x64. E.g, a 64 bit Windows will always lie about + // this and will report x86 anyway for backwards compatibility with 32 bit software. + else ArchitectureType.X86 - private val logger = LoggerFactory.getLogger(getClass) + private val logger = LoggerFactory.getLogger(getClass) - def pathExists(path: String): Boolean = - if !Paths.get(path).toFile.exists() then - logger.debug(s"Input path '$path' does not exist!") - false - else - true + def pathExists(path: String): Boolean = + if !Paths.get(path).toFile.exists() then + logger.debug(s"Input path '$path' does not exist!") + false + else + true end Environment diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/ExternalCommand.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/ExternalCommand.scala index bb0264ff..42b97526 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/ExternalCommand.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/ExternalCommand.scala @@ -8,48 +8,48 @@ import scala.jdk.CollectionConverters.* object ExternalCommand: - private val IS_WIN: Boolean = - scala.util.Properties.isWin - - private val shellPrefix: Seq[String] = - if IS_WIN then "cmd" :: "/c" :: Nil else "sh" :: "-c" :: Nil - - def run(command: String, cwd: String, separateStdErr: Boolean = false): Try[Seq[String]] = - val stdOutOutput = new ConcurrentLinkedQueue[String] - val stdErrOutput = - if separateStdErr then new ConcurrentLinkedQueue[String] else stdOutOutput - val processLogger = ProcessLogger(stdOutOutput.add, stdErrOutput.add) - - Process(shellPrefix :+ command, new java.io.File(cwd)).!(processLogger) match - case 0 => - Success(stdOutOutput.asScala.toSeq) - case _ => - Failure(new RuntimeException(stdErrOutput.asScala.mkString(System.lineSeparator()))) - - private val COMMAND_AND: String = " && " - - def toOSCommand(command: String): String = if IS_WIN then command + ".cmd" else command - - def runMultiple( - command: String, - inDir: String = ".", - extraEnv: Map[String, String] = Map.empty - ): Try[String] = - val dir = new java.io.File(inDir) - val stdOutOutput = new ConcurrentLinkedQueue[String] - val stdErrOutput = new ConcurrentLinkedQueue[String] - val processLogger = ProcessLogger(stdOutOutput.add, stdErrOutput.add) - val commands = command.split(COMMAND_AND).toSeq - commands.map { cmd => - val cmdWithQuotesAroundDir = StringUtils.replace(cmd, inDir, s"'$inDir'") - Try(Process(cmdWithQuotesAroundDir, dir, extraEnv.toList*).!(processLogger)).getOrElse( - 1 - ) - }.sum match - case 0 => - Success(stdOutOutput.asScala.mkString(System.lineSeparator())) - case _ => - val allOutput = stdOutOutput.asScala ++ stdErrOutput.asScala - Failure(new RuntimeException(allOutput.mkString(System.lineSeparator()))) - end runMultiple + private val IS_WIN: Boolean = + scala.util.Properties.isWin + + private val shellPrefix: Seq[String] = + if IS_WIN then "cmd" :: "/c" :: Nil else "sh" :: "-c" :: Nil + + def run(command: String, cwd: String, separateStdErr: Boolean = false): Try[Seq[String]] = + val stdOutOutput = new ConcurrentLinkedQueue[String] + val stdErrOutput = + if separateStdErr then new ConcurrentLinkedQueue[String] else stdOutOutput + val processLogger = ProcessLogger(stdOutOutput.add, stdErrOutput.add) + + Process(shellPrefix :+ command, new java.io.File(cwd)).!(processLogger) match + case 0 => + Success(stdOutOutput.asScala.toSeq) + case _ => + Failure(new RuntimeException(stdErrOutput.asScala.mkString(System.lineSeparator()))) + + private val COMMAND_AND: String = " && " + + def toOSCommand(command: String): String = if IS_WIN then command + ".cmd" else command + + def runMultiple( + command: String, + inDir: String = ".", + extraEnv: Map[String, String] = Map.empty + ): Try[String] = + val dir = new java.io.File(inDir) + val stdOutOutput = new ConcurrentLinkedQueue[String] + val stdErrOutput = new ConcurrentLinkedQueue[String] + val processLogger = ProcessLogger(stdOutOutput.add, stdErrOutput.add) + val commands = command.split(COMMAND_AND).toSeq + commands.map { cmd => + val cmdWithQuotesAroundDir = StringUtils.replace(cmd, inDir, s"'$inDir'") + Try(Process(cmdWithQuotesAroundDir, dir, extraEnv.toList*).!(processLogger)).getOrElse( + 1 + ) + }.sum match + case 0 => + Success(stdOutOutput.asScala.mkString(System.lineSeparator())) + case _ => + val allOutput = stdOutOutput.asScala ++ stdErrOutput.asScala + Failure(new RuntimeException(allOutput.mkString(System.lineSeparator()))) + end runMultiple end ExternalCommand diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/HashUtil.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/HashUtil.scala index f9738d73..0176e2ad 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/HashUtil.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/HashUtil.scala @@ -6,26 +6,26 @@ import scala.util.Using object HashUtil: - def sha256(file: Path): String = - sha256(Seq(file)) + def sha256(file: Path): String = + sha256(Seq(file)) - def sha256(file: String): String = - sha256(Seq(Path.of(file))) + def sha256(file: String): String = + sha256(Seq(Path.of(file))) - def sha256(files: Seq[Path]): String = - val md = MessageDigest.getInstance("SHA-256") - val buffer = new Array[Byte](4096) - files - .filterNot(p => isDirectory(p.toRealPath())) - .foreach { path => - Using.resource(new DigestInputStream(Files.newInputStream(path), md)) { dis => - while dis.available() > 0 do - dis.read(buffer) - } + def sha256(files: Seq[Path]): String = + val md = MessageDigest.getInstance("SHA-256") + val buffer = new Array[Byte](4096) + files + .filterNot(p => isDirectory(p.toRealPath())) + .foreach { path => + Using.resource(new DigestInputStream(Files.newInputStream(path), md)) { dis => + while dis.available() > 0 do + dis.read(buffer) } - md.digest().map(b => String.format("%02x", Byte.box(b))).mkString + } + md.digest().map(b => String.format("%02x", Byte.box(b))).mkString - private def isDirectory(path: Path): Boolean = - if path == null || !Files.exists(path) then false - else Files.isDirectory(path) + private def isDirectory(path: Path): Boolean = + if path == null || !Files.exists(path) then false + else Files.isDirectory(path) end HashUtil diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/LinkingUtil.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/LinkingUtil.scala index 3c463556..4a0a984a 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/LinkingUtil.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/LinkingUtil.scala @@ -22,168 +22,168 @@ import scala.jdk.CollectionConverters.* trait LinkingUtil: - import overflowdb.BatchedUpdate.DiffGraphBuilder + import overflowdb.BatchedUpdate.DiffGraphBuilder - val logger: Logger = LoggerFactory.getLogger(classOf[LinkingUtil]) + val logger: Logger = LoggerFactory.getLogger(classOf[LinkingUtil]) - def typeDeclFullNameToNode(cpg: Cpg, x: String): Option[TypeDecl] = - nodesWithFullName(cpg, x).collectFirst { case x: TypeDecl => x } + def typeDeclFullNameToNode(cpg: Cpg, x: String): Option[TypeDecl] = + nodesWithFullName(cpg, x).collectFirst { case x: TypeDecl => x } - def typeFullNameToNode(cpg: Cpg, x: String): Option[Type] = - nodesWithFullName(cpg, x).collectFirst { case x: Type => x } + def typeFullNameToNode(cpg: Cpg, x: String): Option[Type] = + nodesWithFullName(cpg, x).collectFirst { case x: Type => x } - def methodFullNameToNode(cpg: Cpg, x: String): Option[Method] = - nodesWithFullName(cpg, x).collectFirst { case x: Method => x } + def methodFullNameToNode(cpg: Cpg, x: String): Option[Method] = + nodesWithFullName(cpg, x).collectFirst { case x: Method => x } - def namespaceBlockFullNameToNode(cpg: Cpg, x: String): Option[NamespaceBlock] = - nodesWithFullName(cpg, x).collectFirst { case x: NamespaceBlock => x } + def namespaceBlockFullNameToNode(cpg: Cpg, x: String): Option[NamespaceBlock] = + nodesWithFullName(cpg, x).collectFirst { case x: NamespaceBlock => x } - def nodesWithFullName(cpg: Cpg, x: String): mutable.Seq[NodeRef[? <: NodeDb]] = - cpg.graph.indexManager.lookup(PropertyNames.FULL_NAME, x).asScala + def nodesWithFullName(cpg: Cpg, x: String): mutable.Seq[NodeRef[? <: NodeDb]] = + cpg.graph.indexManager.lookup(PropertyNames.FULL_NAME, x).asScala - /** For all nodes `n` with a label in `srcLabels`, determine the value of `n.\$dstFullNameKey`, - * use that to lookup the destination node in `dstNodeMap`, and create an edge of type - * `edgeType` between `n` and the destination node. - */ - def linkToSingle( - cpg: Cpg, - srcLabels: List[String], - dstNodeLabel: String, - edgeType: String, - dstNodeMap: String => Option[StoredNode], - dstFullNameKey: String, - dstGraph: DiffGraphBuilder, - dstNotExistsHandler: Option[(StoredNode, String) => Unit] - ): Unit = - var loggedDeprecationWarning = false - val dereference = Dereference(cpg) - cpg.graph.nodes(srcLabels*).foreach { srcNode => - // If the source node does not have any outgoing edges of this type - // This check is just required for backward compatibility - if srcNode.outE(edgeType).isEmpty then - val key = new PropertyKey[String](dstFullNameKey) - srcNode - .propertyOption(key) - .filter { dstFullName => - val dereferenceDstFullName = - dereference.dereferenceTypeFullName(dstFullName) - srcNode.propertyDefaultValue(dstFullNameKey) != dereferenceDstFullName - } - .ifPresent { dstFullName => - // for `UNKNOWN` this is not always set, so we're using an Option here - val srcStoredNode = srcNode.asInstanceOf[StoredNode] - val dereferenceDstFullName = - dereference.dereferenceTypeFullName(dstFullName) - dstNodeMap(dereferenceDstFullName) match - case Some(dstNode) => - dstGraph.addEdge(srcStoredNode, dstNode, edgeType) - case None if dstNodeMap(dstFullName).isDefined => - dstGraph.addEdge( - srcStoredNode, - dstNodeMap(dstFullName).get, - edgeType - ) - case None if dstNotExistsHandler.isDefined => - dstNotExistsHandler.get(srcStoredNode, dereferenceDstFullName) - case _ => - logFailedDstLookup( - edgeType, - srcNode.label, - srcNode.id.toString, - dstNodeLabel, - dereferenceDstFullName - ) - } - else - srcNode.out(edgeType).property(Properties.FULL_NAME).nextOption() match - case Some(dstFullName) => - dstGraph.setNodeProperty( - srcNode.asInstanceOf[StoredNode], - dstFullNameKey, - dereference.dereferenceTypeFullName(dstFullName) - ) - case None => - logger.debug(s"Missing outgoing edge of type $edgeType from node $srcNode") - if !loggedDeprecationWarning then - logger.debug( - s"Using deprecated CPG format with already existing $edgeType edge between" + - s" a source node of type $srcLabels and a $dstNodeLabel node." - ) - loggedDeprecationWarning = true - } - end linkToSingle - - def linkToMultiple[SRC_NODE_TYPE <: StoredNode]( - cpg: Cpg, - srcLabels: List[String], - dstNodeLabel: String, - edgeType: String, - dstNodeMap: String => Option[StoredNode], - getDstFullNames: SRC_NODE_TYPE => Iterable[String], - dstFullNameKey: String, - dstGraph: DiffGraphBuilder - ): Unit = - var loggedDeprecationWarning = false - val dereference = Dereference(cpg) - cpg.graph.nodes(srcLabels*).asScala.cast[SRC_NODE_TYPE].foreach { srcNode => - if !srcNode.outE(edgeType).hasNext then - getDstFullNames(srcNode).foreach { dstFullName => - val dereferenceDstFullName = dereference.dereferenceTypeFullName(dstFullName) - dstNodeMap(dereferenceDstFullName) match - case Some(dstNode) => - dstGraph.addEdge(srcNode, dstNode, edgeType) - case None if dstNodeMap(dstFullName).isDefined => - dstGraph.addEdge(srcNode, dstNodeMap(dstFullName).get, edgeType) - case None => - logFailedDstLookup( - edgeType, - srcNode.label, - srcNode.id.toString, - dstNodeLabel, - dereferenceDstFullName - ) - } - else - val dstFullNames = srcNode.out(edgeType).property(Properties.FULL_NAME).l + /** For all nodes `n` with a label in `srcLabels`, determine the value of `n.\$dstFullNameKey`, + * use that to lookup the destination node in `dstNodeMap`, and create an edge of type `edgeType` + * between `n` and the destination node. + */ + def linkToSingle( + cpg: Cpg, + srcLabels: List[String], + dstNodeLabel: String, + edgeType: String, + dstNodeMap: String => Option[StoredNode], + dstFullNameKey: String, + dstGraph: DiffGraphBuilder, + dstNotExistsHandler: Option[(StoredNode, String) => Unit] + ): Unit = + var loggedDeprecationWarning = false + val dereference = Dereference(cpg) + cpg.graph.nodes(srcLabels*).foreach { srcNode => + // If the source node does not have any outgoing edges of this type + // This check is just required for backward compatibility + if srcNode.outE(edgeType).isEmpty then + val key = new PropertyKey[String](dstFullNameKey) + srcNode + .propertyOption(key) + .filter { dstFullName => + val dereferenceDstFullName = + dereference.dereferenceTypeFullName(dstFullName) + srcNode.propertyDefaultValue(dstFullNameKey) != dereferenceDstFullName + } + .ifPresent { dstFullName => + // for `UNKNOWN` this is not always set, so we're using an Option here + val srcStoredNode = srcNode.asInstanceOf[StoredNode] + val dereferenceDstFullName = + dereference.dereferenceTypeFullName(dstFullName) + dstNodeMap(dereferenceDstFullName) match + case Some(dstNode) => + dstGraph.addEdge(srcStoredNode, dstNode, edgeType) + case None if dstNodeMap(dstFullName).isDefined => + dstGraph.addEdge( + srcStoredNode, + dstNodeMap(dstFullName).get, + edgeType + ) + case None if dstNotExistsHandler.isDefined => + dstNotExistsHandler.get(srcStoredNode, dereferenceDstFullName) + case _ => + logFailedDstLookup( + edgeType, + srcNode.label, + srcNode.id.toString, + dstNodeLabel, + dereferenceDstFullName + ) + } + else + srcNode.out(edgeType).property(Properties.FULL_NAME).nextOption() match + case Some(dstFullName) => dstGraph.setNodeProperty( - srcNode, + srcNode.asInstanceOf[StoredNode], dstFullNameKey, - dstFullNames.map(dereference.dereferenceTypeFullName) + dereference.dereferenceTypeFullName(dstFullName) ) - if !loggedDeprecationWarning then - logger.debug( - s"Using deprecated CPG format with already existing $edgeType edge between" + - s" a source node of type $srcLabels and a $dstNodeLabel node." - ) - loggedDeprecationWarning = true - } - end linkToMultiple + case None => + logger.debug(s"Missing outgoing edge of type $edgeType from node $srcNode") + if !loggedDeprecationWarning then + logger.debug( + s"Using deprecated CPG format with already existing $edgeType edge between" + + s" a source node of type $srcLabels and a $dstNodeLabel node." + ) + loggedDeprecationWarning = true + } + end linkToSingle + + def linkToMultiple[SRC_NODE_TYPE <: StoredNode]( + cpg: Cpg, + srcLabels: List[String], + dstNodeLabel: String, + edgeType: String, + dstNodeMap: String => Option[StoredNode], + getDstFullNames: SRC_NODE_TYPE => Iterable[String], + dstFullNameKey: String, + dstGraph: DiffGraphBuilder + ): Unit = + var loggedDeprecationWarning = false + val dereference = Dereference(cpg) + cpg.graph.nodes(srcLabels*).asScala.cast[SRC_NODE_TYPE].foreach { srcNode => + if !srcNode.outE(edgeType).hasNext then + getDstFullNames(srcNode).foreach { dstFullName => + val dereferenceDstFullName = dereference.dereferenceTypeFullName(dstFullName) + dstNodeMap(dereferenceDstFullName) match + case Some(dstNode) => + dstGraph.addEdge(srcNode, dstNode, edgeType) + case None if dstNodeMap(dstFullName).isDefined => + dstGraph.addEdge(srcNode, dstNodeMap(dstFullName).get, edgeType) + case None => + logFailedDstLookup( + edgeType, + srcNode.label, + srcNode.id.toString, + dstNodeLabel, + dereferenceDstFullName + ) + } + else + val dstFullNames = srcNode.out(edgeType).property(Properties.FULL_NAME).l + dstGraph.setNodeProperty( + srcNode, + dstFullNameKey, + dstFullNames.map(dereference.dereferenceTypeFullName) + ) + if !loggedDeprecationWarning then + logger.debug( + s"Using deprecated CPG format with already existing $edgeType edge between" + + s" a source node of type $srcLabels and a $dstNodeLabel node." + ) + loggedDeprecationWarning = true + } + end linkToMultiple - @inline - protected def logFailedDstLookup( - edgeType: String, - srcNodeType: String, - srcNodeId: String, - dstNodeType: String, - dstFullName: String - ): Unit = - logger.debug( - "Could not create edge. Destination lookup failed. " + - s"edgeType=$edgeType, srcNodeType=$srcNodeType, srcNodeId=$srcNodeId, " + - s"dstNodeType=$dstNodeType, dstFullName=$dstFullName" - ) + @inline + protected def logFailedDstLookup( + edgeType: String, + srcNodeType: String, + srcNodeId: String, + dstNodeType: String, + dstFullName: String + ): Unit = + logger.debug( + "Could not create edge. Destination lookup failed. " + + s"edgeType=$edgeType, srcNodeType=$srcNodeType, srcNodeId=$srcNodeId, " + + s"dstNodeType=$dstNodeType, dstFullName=$dstFullName" + ) - @inline - protected def logFailedSrcLookup( - edgeType: String, - srcNodeType: String, - srcFullName: String, - dstNodeType: String, - dstNodeId: String - ): Unit = - logger.debug( - "Could not create edge. Source lookup failed. " + - s"edgeType=$edgeType, srcNodeType=$srcNodeType, srcFullName=$srcFullName, " + - s"dstNodeType=$dstNodeType, dstNodeId=$dstNodeId" - ) + @inline + protected def logFailedSrcLookup( + edgeType: String, + srcNodeType: String, + srcFullName: String, + dstNodeType: String, + dstNodeId: String + ): Unit = + logger.debug( + "Could not create edge. Source lookup failed. " + + s"edgeType=$edgeType, srcNodeType=$srcNodeType, srcFullName=$srcFullName, " + + s"dstNodeType=$dstNodeType, dstNodeId=$dstNodeId" + ) end LinkingUtil diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/ListUtils.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/ListUtils.scala index 582c9f70..57032c3a 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/ListUtils.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/ListUtils.scala @@ -1,15 +1,15 @@ package io.appthreat.x2cpg.utils object ListUtils: - extension [T](list: List[T]) - /** Return each element in the list up to and including the first element which satisfies - * the given predicate, or return the empty list if no matching element is found, e.g. - * - * List(1, 2, 3, 4, 1, 2).takeUntil(_ >= 3) => List(1, 2, 3) - * - * List(1, 2, 3, 4, 1, 2).takeUntil(_ >= 5) => Nil - */ - def takeUntil(predicate: T => Boolean): List[T] = - list.indexWhere(predicate) match - case index if index >= 0 => list.take(index + 1) - case _ => Nil + extension [T](list: List[T]) + /** Return each element in the list up to and including the first element which satisfies the + * given predicate, or return the empty list if no matching element is found, e.g. + * + * List(1, 2, 3, 4, 1, 2).takeUntil(_ >= 3) => List(1, 2, 3) + * + * List(1, 2, 3, 4, 1, 2).takeUntil(_ >= 5) => Nil + */ + def takeUntil(predicate: T => Boolean): List[T] = + list.indexWhere(predicate) match + case index if index >= 0 => list.take(index + 1) + case _ => Nil diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/NodeBuilders.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/NodeBuilders.scala index e57f0b6f..e066eb81 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/NodeBuilders.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/NodeBuilders.scala @@ -22,162 +22,162 @@ import io.shiftleft.codepropertygraph.generated.{DispatchTypes, EvaluationStrate */ object NodeBuilders: - private def composeCallSignature(returnType: String, argumentTypes: Iterable[String]): String = - s"$returnType(${argumentTypes.mkString(",")})" - - private def composeMethodFullName( - typeDeclFullName: Option[String], - name: String, - signature: String - ) = - val typeDeclPrefix = typeDeclFullName.map(maybeName => s"$maybeName.").getOrElse("") - s"$typeDeclPrefix$name:$signature" - - def newAnnotationLiteralNode(name: String): NewAnnotationLiteral = - NewAnnotationLiteral() - .name(name) - .code(name) - - def newBindingNode(name: String, signature: String, methodFullName: String): NewBinding = - NewBinding() - .name(name) - .methodFullName(methodFullName) - .signature(signature) - - def newLocalNode( - name: String, - typeFullName: String, - closureBindingId: Option[String] = None - ): NewLocal = - NewLocal() - .code(name) - .name(name) - .typeFullName(typeFullName) - .closureBindingId(closureBindingId) - - def newClosureBindingNode( - closureBindingId: String, - originalName: String, - evaluationStrategy: String - ): NewClosureBinding = - NewClosureBinding() - .closureBindingId(closureBindingId) - .closureOriginalName(originalName) - .evaluationStrategy(evaluationStrategy) - - def newCallNode( - methodName: String, - typeDeclFullName: Option[String], - returnTypeFullName: String, - dispatchType: String, - argumentTypes: Iterable[String] = Nil, - code: String = PropertyDefaults.Code, - lineNumber: Option[Integer] = None, - columnNumber: Option[Integer] = None - ): NewCall = - val signature = composeCallSignature(returnTypeFullName, argumentTypes) - val methodFullName = composeMethodFullName(typeDeclFullName, methodName, signature) - NewCall() - .name(methodName) - .methodFullName(methodFullName) - .signature(signature) - .typeFullName(returnTypeFullName) - .dispatchType(dispatchType) - .code(code) - .lineNumber(lineNumber) - .columnNumber(columnNumber) - end newCallNode - - def newDependencyNode(name: String, groupId: String, version: String): NewDependency = - NewDependency() - .name(name) - .dependencyGroupId(groupId) - .version(version) - - def newFieldIdentifierNode( - name: String, - line: Option[Integer] = None, - column: Option[Integer] = None - ): NewFieldIdentifier = - NewFieldIdentifier() - .canonicalName(name) - .code(name) - .lineNumber(line) - .columnNumber(column) - - def newModifierNode(modifierType: String): NewModifier = - NewModifier().modifierType(modifierType) - - def newIdentifierNode( - name: String, - typeFullName: String, - dynamicTypeHints: Seq[String] = Seq() - ): NewIdentifier = - newIdentifierNode(name, typeFullName, dynamicTypeHints, None) - - def newIdentifierNode( - name: String, - typeFullName: String, - dynamicTypeHints: Seq[String], - line: Option[Integer] - ): NewIdentifier = - NewIdentifier() - .code(name) - .name(name) - .typeFullName(typeFullName) - .dynamicTypeHintFullName(dynamicTypeHints) - .lineNumber(line) - - def newOperatorCallNode( - name: String, - code: String, - typeFullName: Option[String] = None, - line: Option[Integer] = None, - column: Option[Integer] = None - ): NewCall = - NewCall() - .name(name) - .methodFullName(name) - .code(code) - .signature("") - .dispatchType(DispatchTypes.STATIC_DISPATCH) - .typeFullName(typeFullName.getOrElse("ANY")) - .lineNumber(line) - .columnNumber(column) - - def newThisParameterNode( - name: String = "this", - code: String = "this", - typeFullName: String, - dynamicTypeHintFullName: Seq[String] = Seq.empty, - line: Option[Integer] = None, - column: Option[Integer] = None, - evaluationStrategy: String = EvaluationStrategies.BY_SHARING - ): NewMethodParameterIn = - NewMethodParameterIn() - .name(name) - .code(code) - .lineNumber(line) - .columnNumber(column) - .dynamicTypeHintFullName(dynamicTypeHintFullName) - .evaluationStrategy(evaluationStrategy) - .typeFullName(typeFullName) - .index(0) - .order(0) - - /** Create a method return node - */ - def newMethodReturnNode( - typeFullName: String, - dynamicTypeHintFullName: Option[String] = None, - line: Option[Integer], - column: Option[Integer] - ): NewMethodReturn = - NewMethodReturn() - .typeFullName(typeFullName) - .dynamicTypeHintFullName(dynamicTypeHintFullName) - .code("RET") - .evaluationStrategy(EvaluationStrategies.BY_VALUE) - .lineNumber(line) - .columnNumber(column) + private def composeCallSignature(returnType: String, argumentTypes: Iterable[String]): String = + s"$returnType(${argumentTypes.mkString(",")})" + + private def composeMethodFullName( + typeDeclFullName: Option[String], + name: String, + signature: String + ) = + val typeDeclPrefix = typeDeclFullName.map(maybeName => s"$maybeName.").getOrElse("") + s"$typeDeclPrefix$name:$signature" + + def newAnnotationLiteralNode(name: String): NewAnnotationLiteral = + NewAnnotationLiteral() + .name(name) + .code(name) + + def newBindingNode(name: String, signature: String, methodFullName: String): NewBinding = + NewBinding() + .name(name) + .methodFullName(methodFullName) + .signature(signature) + + def newLocalNode( + name: String, + typeFullName: String, + closureBindingId: Option[String] = None + ): NewLocal = + NewLocal() + .code(name) + .name(name) + .typeFullName(typeFullName) + .closureBindingId(closureBindingId) + + def newClosureBindingNode( + closureBindingId: String, + originalName: String, + evaluationStrategy: String + ): NewClosureBinding = + NewClosureBinding() + .closureBindingId(closureBindingId) + .closureOriginalName(originalName) + .evaluationStrategy(evaluationStrategy) + + def newCallNode( + methodName: String, + typeDeclFullName: Option[String], + returnTypeFullName: String, + dispatchType: String, + argumentTypes: Iterable[String] = Nil, + code: String = PropertyDefaults.Code, + lineNumber: Option[Integer] = None, + columnNumber: Option[Integer] = None + ): NewCall = + val signature = composeCallSignature(returnTypeFullName, argumentTypes) + val methodFullName = composeMethodFullName(typeDeclFullName, methodName, signature) + NewCall() + .name(methodName) + .methodFullName(methodFullName) + .signature(signature) + .typeFullName(returnTypeFullName) + .dispatchType(dispatchType) + .code(code) + .lineNumber(lineNumber) + .columnNumber(columnNumber) + end newCallNode + + def newDependencyNode(name: String, groupId: String, version: String): NewDependency = + NewDependency() + .name(name) + .dependencyGroupId(groupId) + .version(version) + + def newFieldIdentifierNode( + name: String, + line: Option[Integer] = None, + column: Option[Integer] = None + ): NewFieldIdentifier = + NewFieldIdentifier() + .canonicalName(name) + .code(name) + .lineNumber(line) + .columnNumber(column) + + def newModifierNode(modifierType: String): NewModifier = + NewModifier().modifierType(modifierType) + + def newIdentifierNode( + name: String, + typeFullName: String, + dynamicTypeHints: Seq[String] = Seq() + ): NewIdentifier = + newIdentifierNode(name, typeFullName, dynamicTypeHints, None) + + def newIdentifierNode( + name: String, + typeFullName: String, + dynamicTypeHints: Seq[String], + line: Option[Integer] + ): NewIdentifier = + NewIdentifier() + .code(name) + .name(name) + .typeFullName(typeFullName) + .dynamicTypeHintFullName(dynamicTypeHints) + .lineNumber(line) + + def newOperatorCallNode( + name: String, + code: String, + typeFullName: Option[String] = None, + line: Option[Integer] = None, + column: Option[Integer] = None + ): NewCall = + NewCall() + .name(name) + .methodFullName(name) + .code(code) + .signature("") + .dispatchType(DispatchTypes.STATIC_DISPATCH) + .typeFullName(typeFullName.getOrElse("ANY")) + .lineNumber(line) + .columnNumber(column) + + def newThisParameterNode( + name: String = "this", + code: String = "this", + typeFullName: String, + dynamicTypeHintFullName: Seq[String] = Seq.empty, + line: Option[Integer] = None, + column: Option[Integer] = None, + evaluationStrategy: String = EvaluationStrategies.BY_SHARING + ): NewMethodParameterIn = + NewMethodParameterIn() + .name(name) + .code(code) + .lineNumber(line) + .columnNumber(column) + .dynamicTypeHintFullName(dynamicTypeHintFullName) + .evaluationStrategy(evaluationStrategy) + .typeFullName(typeFullName) + .index(0) + .order(0) + + /** Create a method return node + */ + def newMethodReturnNode( + typeFullName: String, + dynamicTypeHintFullName: Option[String] = None, + line: Option[Integer], + column: Option[Integer] + ): NewMethodReturn = + NewMethodReturn() + .typeFullName(typeFullName) + .dynamicTypeHintFullName(dynamicTypeHintFullName) + .code("RET") + .evaluationStrategy(EvaluationStrategies.BY_VALUE) + .lineNumber(line) + .columnNumber(column) end NodeBuilders diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/Report.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/Report.scala index 6546f5c0..ac82873a 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/Report.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/Report.scala @@ -6,80 +6,80 @@ import scala.collection.concurrent.TrieMap object Report: - private val logger = LoggerFactory.getLogger(Report.getClass) + private val logger = LoggerFactory.getLogger(Report.getClass) - private type FileName = String + private type FileName = String - private type Reports = TrieMap[FileName, ReportEntry] + private type Reports = TrieMap[FileName, ReportEntry] - private case class ReportEntry(loc: Int, parsed: Boolean, cpgGen: Boolean, duration: Long): - def toSeq: Seq[String] = - val lines = loc.toString - val dur = if duration == 0 then "-" else TimeUtils.pretty(duration) - val wasParsed = if parsed then "yes" else "no" - val gotCpg = if cpgGen then "yes" else "no" - Seq(lines, wasParsed, gotCpg, dur) + private case class ReportEntry(loc: Int, parsed: Boolean, cpgGen: Boolean, duration: Long): + def toSeq: Seq[String] = + val lines = loc.toString + val dur = if duration == 0 then "-" else TimeUtils.pretty(duration) + val wasParsed = if parsed then "yes" else "no" + val gotCpg = if cpgGen then "yes" else "no" + Seq(lines, wasParsed, gotCpg, dur) class Report: - import Report.* + import Report.* - private val reports: Reports = TrieMap.empty + private val reports: Reports = TrieMap.empty - private def formatTable(table: Seq[Seq[String]]): String = - if table.isEmpty then "" - else - // Get column widths based on the maximum cell width in each column (+2 for a one character padding on each side) - val colWidths = - table.transpose.map(_.map(cell => if cell == null then 0 else cell.length).max + 2) - // Format each row - val rows = table.map( - _.zip(colWidths) - .map { case (item, size) => s" %-${size - 1}s".format(item) } - .mkString("|", "|", "|") - ) - // Formatted separator row, used to separate the header and draw table borders - val separator = colWidths.map("-" * _).mkString("+", "+", "+") - // Put the table together and return - val header = rows.head - val content = rows.tail.take(rows.tail.size - 1) - val footer = rows.tail.last - (separator +: header +: separator +: content :+ separator :+ footer :+ separator) - .mkString("\n") - - def print(): Unit = - val rows = reports.toSeq - .sortBy(_._1) - .zipWithIndex - .view - .map { case ((file, sum), index) => - s"${index + 1}" +: file +: sum.toSeq - } - .toSeq - val numOfReports = reports.size - val header = Seq(Seq("#", "File", "LOC", "Parsed", "Got a CPG", "Duration")) - val footer = Seq( - Seq( - "Total", - "", - s"${reports.map(_._2.loc).sum}", - s"${reports.count(_._2.parsed)}/$numOfReports", - s"${reports.count(_._2.cpgGen)}/$numOfReports", - "" - ) + private def formatTable(table: Seq[Seq[String]]): String = + if table.isEmpty then "" + else + // Get column widths based on the maximum cell width in each column (+2 for a one character padding on each side) + val colWidths = + table.transpose.map(_.map(cell => if cell == null then 0 else cell.length).max + 2) + // Format each row + val rows = table.map( + _.zip(colWidths) + .map { case (item, size) => s" %-${size - 1}s".format(item) } + .mkString("|", "|", "|") ) - val table = header ++ rows ++ footer - logger.debug(s"Report:${System.lineSeparator()}${formatTable(table)}") - end print + // Formatted separator row, used to separate the header and draw table borders + val separator = colWidths.map("-" * _).mkString("+", "+", "+") + // Put the table together and return + val header = rows.head + val content = rows.tail.take(rows.tail.size - 1) + val footer = rows.tail.last + (separator +: header +: separator +: content :+ separator :+ footer :+ separator) + .mkString("\n") + + def print(): Unit = + val rows = reports.toSeq + .sortBy(_._1) + .zipWithIndex + .view + .map { case ((file, sum), index) => + s"${index + 1}" +: file +: sum.toSeq + } + .toSeq + val numOfReports = reports.size + val header = Seq(Seq("#", "File", "LOC", "Parsed", "Got a CPG", "Duration")) + val footer = Seq( + Seq( + "Total", + "", + s"${reports.map(_._2.loc).sum}", + s"${reports.count(_._2.parsed)}/$numOfReports", + s"${reports.count(_._2.cpgGen)}/$numOfReports", + "" + ) + ) + val table = header ++ rows ++ footer + logger.debug(s"Report:${System.lineSeparator()}${formatTable(table)}") + end print - def addReportInfo( - fileName: FileName, - loc: Int, - parsed: Boolean = false, - cpgGen: Boolean = false, - duration: Long = 0 - ): Unit = reports(fileName) = ReportEntry(loc, parsed, cpgGen, duration) + def addReportInfo( + fileName: FileName, + loc: Int, + parsed: Boolean = false, + cpgGen: Boolean = false, + duration: Long = 0 + ): Unit = reports(fileName) = ReportEntry(loc, parsed, cpgGen, duration) - def updateReport(fileName: FileName, cpg: Boolean, duration: Long): Unit = - reports.updateWith(fileName)(_.map(_.copy(cpgGen = cpg, duration = duration))) + def updateReport(fileName: FileName, cpg: Boolean, duration: Long): Unit = + reports.updateWith(fileName)(_.map(_.copy(cpgGen = cpg, duration = duration))) end Report diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/StringUtils.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/StringUtils.scala index 2551d443..da7b9094 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/StringUtils.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/StringUtils.scala @@ -1,5 +1,5 @@ package io.appthreat.x2cpg.utils implicit class StringUtils(str: String): - def isAllUpperCase: Boolean = - str.forall(c => c.isUpper || !c.isLetter) + def isAllUpperCase: Boolean = + str.forall(c => c.isUpper || !c.isLetter) diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/TimeUtils.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/TimeUtils.scala index 86922831..f9ad4cb5 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/TimeUtils.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/TimeUtils.scala @@ -9,51 +9,51 @@ import scala.util.Try object TimeUtils: - /** Measures elapsed time for executing a block in nanoseconds */ - def time[R](block: => R): (R, Long) = - val t0 = System.nanoTime() - val result = block - val t1 = System.nanoTime() - val elapsed = t1 - t0 - (result, elapsed) - - /** Selects most appropriate TimeUnit for given duration and formats it accordingly */ - def pretty(duration: Long): String = pretty(Duration.fromNanos(duration)) - - def runWithTimeout[T](timeoutMs: Long)(f: => T): Try[T] = - Try(Await.result(Future(f), timeoutMs milliseconds)) - - private def pretty(duration: Duration): String = - duration match - case d: FiniteDuration => - val nanos = d.toNanos - val unit = chooseUnit(nanos) - val value = nanos.toDouble / NANOSECONDS.convert(1, unit) - - s"%.4g %s".formatLocal(Locale.ROOT, value, abbreviate(unit)) - - case Duration.MinusInf => s"-∞ (minus infinity)" - case Duration.Inf => s"∞ (infinity)" - case _ => "undefined" - - private def chooseUnit(nanos: Long): TimeUnit = - val d = nanos.nanos - - if d.toDays > 0 then DAYS - else if d.toHours > 0 then HOURS - else if d.toMinutes > 0 then MINUTES - else if d.toSeconds > 0 then SECONDS - else if d.toMillis > 0 then MILLISECONDS - else if d.toMicros > 0 then MICROSECONDS - else NANOSECONDS - - private def abbreviate(unit: TimeUnit): String = - unit match - case NANOSECONDS => "ns" - case MICROSECONDS => "μs" - case MILLISECONDS => "ms" - case SECONDS => "s" - case MINUTES => "min" - case HOURS => "h" - case DAYS => "d" + /** Measures elapsed time for executing a block in nanoseconds */ + def time[R](block: => R): (R, Long) = + val t0 = System.nanoTime() + val result = block + val t1 = System.nanoTime() + val elapsed = t1 - t0 + (result, elapsed) + + /** Selects most appropriate TimeUnit for given duration and formats it accordingly */ + def pretty(duration: Long): String = pretty(Duration.fromNanos(duration)) + + def runWithTimeout[T](timeoutMs: Long)(f: => T): Try[T] = + Try(Await.result(Future(f), timeoutMs milliseconds)) + + private def pretty(duration: Duration): String = + duration match + case d: FiniteDuration => + val nanos = d.toNanos + val unit = chooseUnit(nanos) + val value = nanos.toDouble / NANOSECONDS.convert(1, unit) + + s"%.4g %s".formatLocal(Locale.ROOT, value, abbreviate(unit)) + + case Duration.MinusInf => s"-∞ (minus infinity)" + case Duration.Inf => s"∞ (infinity)" + case _ => "undefined" + + private def chooseUnit(nanos: Long): TimeUnit = + val d = nanos.nanos + + if d.toDays > 0 then DAYS + else if d.toHours > 0 then HOURS + else if d.toMinutes > 0 then MINUTES + else if d.toSeconds > 0 then SECONDS + else if d.toMillis > 0 then MILLISECONDS + else if d.toMicros > 0 then MICROSECONDS + else NANOSECONDS + + private def abbreviate(unit: TimeUnit): String = + unit match + case NANOSECONDS => "ns" + case MICROSECONDS => "μs" + case MILLISECONDS => "ms" + case SECONDS => "s" + case MINUTES => "min" + case HOURS => "h" + case DAYS => "d" end TimeUtils diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/dependency/DependencyResolver.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/dependency/DependencyResolver.scala index ac8ada38..ea42850a 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/dependency/DependencyResolver.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/dependency/DependencyResolver.scala @@ -9,118 +9,118 @@ import java.nio.file.Path import scala.util.{Failure, Success} object GradleConfigKeys extends Enumeration: - type GradleConfigKey = Value - val ProjectName, ConfigurationName = Value + type GradleConfigKey = Value + val ProjectName, ConfigurationName = Value case class DependencyResolverParams( forMaven: Map[String, String] = Map(), forGradle: Map[GradleConfigKey, String] = Map() ) object DependencyResolver: - private val logger = LoggerFactory.getLogger(getClass) - private val defaultGradleProjectName = "app" - private val defaultGradleConfigurationName = "compileClasspath" - private val MaxSearchDepth: Int = 4 - - def getCoordinates( - projectDir: Path, - params: DependencyResolverParams = new DependencyResolverParams - ): Option[collection.Seq[String]] = - val coordinates = findSupportedBuildFiles(projectDir).flatMap { buildFile => - if isMavenBuildFile(buildFile) then - // TODO: implement - None - else if isGradleBuildFile(buildFile) then - getCoordinatesForGradleProject(buildFile.getParent, defaultGradleConfigurationName) - else - logger.debug(s"Found unsupported build file $buildFile") - Nil - }.flatten - - Option.when(coordinates.nonEmpty)(coordinates) - - private def getCoordinatesForGradleProject( - projectDir: Path, - configuration: String - ): Option[collection.Seq[String]] = - val lines = ExternalCommand.run( - s"gradle dependencies --configuration $configuration", - projectDir.toString - ) match - case Success(lines) => lines - case Failure(exception) => - logger.debug( - s"Could not retrieve dependencies for Gradle project at path `$projectDir`\n" + - exception.getMessage - ) - Seq() - - val coordinates = MavenCoordinates.fromGradleOutput(lines) - logger.debug("Got {} Maven coordinates", coordinates.size) - Some(coordinates) - end getCoordinatesForGradleProject - - def getDependencies( - projectDir: Path, - params: DependencyResolverParams = new DependencyResolverParams - ): Option[collection.Seq[String]] = - val dependencies = findSupportedBuildFiles(projectDir).flatMap { buildFile => - if isMavenBuildFile(buildFile) then - MavenDependencies.get(buildFile.getParent) - else if isGradleBuildFile(buildFile) then - getDepsForGradleProject(params, buildFile.getParent) - else - logger.debug(s"Found unsupported build file $buildFile") - Nil - }.flatten - - Option.when(dependencies.nonEmpty)(dependencies) - - private def getDepsForGradleProject( - params: DependencyResolverParams, - projectDir: Path - ): Option[collection.Seq[String]] = - logger.debug("resolving Gradle dependencies at {}", projectDir) - val gradleProjectName = - params.forGradle.getOrElse(GradleConfigKeys.ProjectName, defaultGradleProjectName) - val gradleConfiguration = - params.forGradle.getOrElse( - GradleConfigKeys.ConfigurationName, - defaultGradleConfigurationName - ) - GradleDependencies.get(projectDir, gradleProjectName, gradleConfiguration) match - case Some(deps) => Some(deps) - case None => - logger.debug( - s"Could not download Gradle dependencies for project at path `$projectDir`" - ) - None - end getDepsForGradleProject - - private def isGradleBuildFile(file: File): Boolean = - val pathString = file.pathAsString - pathString.endsWith(".gradle") || pathString.endsWith(".gradle.kts") - - private def isMavenBuildFile(file: File): Boolean = - file.pathAsString.endsWith("pom.xml") - - private def findSupportedBuildFiles(currentDir: File, depth: Int = 0): List[Path] = - if depth >= MaxSearchDepth then - logger.debug("findSupportedBuildFiles reached max depth before finding build files") - Nil + private val logger = LoggerFactory.getLogger(getClass) + private val defaultGradleProjectName = "app" + private val defaultGradleConfigurationName = "compileClasspath" + private val MaxSearchDepth: Int = 4 + + def getCoordinates( + projectDir: Path, + params: DependencyResolverParams = new DependencyResolverParams + ): Option[collection.Seq[String]] = + val coordinates = findSupportedBuildFiles(projectDir).flatMap { buildFile => + if isMavenBuildFile(buildFile) then + // TODO: implement + None + else if isGradleBuildFile(buildFile) then + getCoordinatesForGradleProject(buildFile.getParent, defaultGradleConfigurationName) else - val (childDirectories, childFiles) = currentDir.children.partition(_.isDirectory) - // Only fetch dependencies once for projects with both a build.gradle and a pom.xml file - val childFileList = childFiles.toList - childFileList - .find(isGradleBuildFile) - .orElse(childFileList.find(isMavenBuildFile)) match - case Some(buildFile) => buildFile.path :: Nil - - case None if childDirectories.isEmpty => Nil - - case None => - childDirectories.flatMap { dir => - findSupportedBuildFiles(dir, depth + 1) - }.toList + logger.debug(s"Found unsupported build file $buildFile") + Nil + }.flatten + + Option.when(coordinates.nonEmpty)(coordinates) + + private def getCoordinatesForGradleProject( + projectDir: Path, + configuration: String + ): Option[collection.Seq[String]] = + val lines = ExternalCommand.run( + s"gradle dependencies --configuration $configuration", + projectDir.toString + ) match + case Success(lines) => lines + case Failure(exception) => + logger.debug( + s"Could not retrieve dependencies for Gradle project at path `$projectDir`\n" + + exception.getMessage + ) + Seq() + + val coordinates = MavenCoordinates.fromGradleOutput(lines) + logger.debug("Got {} Maven coordinates", coordinates.size) + Some(coordinates) + end getCoordinatesForGradleProject + + def getDependencies( + projectDir: Path, + params: DependencyResolverParams = new DependencyResolverParams + ): Option[collection.Seq[String]] = + val dependencies = findSupportedBuildFiles(projectDir).flatMap { buildFile => + if isMavenBuildFile(buildFile) then + MavenDependencies.get(buildFile.getParent) + else if isGradleBuildFile(buildFile) then + getDepsForGradleProject(params, buildFile.getParent) + else + logger.debug(s"Found unsupported build file $buildFile") + Nil + }.flatten + + Option.when(dependencies.nonEmpty)(dependencies) + + private def getDepsForGradleProject( + params: DependencyResolverParams, + projectDir: Path + ): Option[collection.Seq[String]] = + logger.debug("resolving Gradle dependencies at {}", projectDir) + val gradleProjectName = + params.forGradle.getOrElse(GradleConfigKeys.ProjectName, defaultGradleProjectName) + val gradleConfiguration = + params.forGradle.getOrElse( + GradleConfigKeys.ConfigurationName, + defaultGradleConfigurationName + ) + GradleDependencies.get(projectDir, gradleProjectName, gradleConfiguration) match + case Some(deps) => Some(deps) + case None => + logger.debug( + s"Could not download Gradle dependencies for project at path `$projectDir`" + ) + None + end getDepsForGradleProject + + private def isGradleBuildFile(file: File): Boolean = + val pathString = file.pathAsString + pathString.endsWith(".gradle") || pathString.endsWith(".gradle.kts") + + private def isMavenBuildFile(file: File): Boolean = + file.pathAsString.endsWith("pom.xml") + + private def findSupportedBuildFiles(currentDir: File, depth: Int = 0): List[Path] = + if depth >= MaxSearchDepth then + logger.debug("findSupportedBuildFiles reached max depth before finding build files") + Nil + else + val (childDirectories, childFiles) = currentDir.children.partition(_.isDirectory) + // Only fetch dependencies once for projects with both a build.gradle and a pom.xml file + val childFileList = childFiles.toList + childFileList + .find(isGradleBuildFile) + .orElse(childFileList.find(isMavenBuildFile)) match + case Some(buildFile) => buildFile.path :: Nil + + case None if childDirectories.isEmpty => Nil + + case None => + childDirectories.flatMap { dir => + findSupportedBuildFiles(dir, depth + 1) + }.toList end DependencyResolver diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/dependency/GradleDependencies.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/dependency/GradleDependencies.scala index 01af3541..9aa5a3a0 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/dependency/GradleDependencies.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/dependency/GradleDependencies.scala @@ -18,39 +18,39 @@ case class GradleProjectInfo( tasks: Seq[String], hasAndroidSubproject: Boolean = false ): - def gradleVersionMajorMinor(): (Int, Int) = - def isValidPart(part: String) = part.forall(Character.isDigit) - val parts = gradleVersion.split('.') - if parts.length == 1 && isValidPart(parts(0)) then - (parts(0).toInt, 0) - else if parts.length >= 2 && isValidPart(parts(0)) && isValidPart(parts(1)) then - (parts(0).toInt, parts(1).toInt) - else - (-1, -1) + def gradleVersionMajorMinor(): (Int, Int) = + def isValidPart(part: String) = part.forall(Character.isDigit) + val parts = gradleVersion.split('.') + if parts.length == 1 && isValidPart(parts(0)) then + (parts(0).toInt, 0) + else if parts.length >= 2 && isValidPart(parts(0)) && isValidPart(parts(1)) then + (parts(0).toInt, parts(1).toInt) + else + (-1, -1) object Constants: - val aarFileExtension = "aar" - val gradleAndroidPropertyPrefix = "android." - val gradlePropertiesTaskName = "properties" - val jarInsideAarFileName = "classes.jar" + val aarFileExtension = "aar" + val gradleAndroidPropertyPrefix = "android." + val gradlePropertiesTaskName = "properties" + val jarInsideAarFileName = "classes.jar" case class GradleDepsInitScript(contents: String, taskName: String, destinationDir: Path) object GradleDependencies: - private val logger = LoggerFactory.getLogger(getClass) - private val initScriptPrefix = "x2cpg.init.gradle" - private val taskNamePrefix = "x2cpgCopyDeps" - private val tempDirPrefix = "x2cpgDependencies" + private val logger = LoggerFactory.getLogger(getClass) + private val initScriptPrefix = "x2cpg.init.gradle" + private val taskNamePrefix = "x2cpgCopyDeps" + private val tempDirPrefix = "x2cpgDependencies" - // works with Gradle 5.1+ because the script makes use of `task.register`: - // https://docs.gradle.org/current/userguide/task_configuration_avoidance.html - private def gradle5OrLaterAndroidInitScript( - taskName: String, - destination: String, - gradleProjectName: String, - gradleConfigurationName: String - ): String = - s""" + // works with Gradle 5.1+ because the script makes use of `task.register`: + // https://docs.gradle.org/current/userguide/task_configuration_avoidance.html + private def gradle5OrLaterAndroidInitScript( + taskName: String, + destination: String, + gradleProjectName: String, + gradleConfigurationName: String + ): String = + s""" |allprojects { | afterEvaluate { project -> | def taskName = "$taskName" @@ -89,10 +89,10 @@ object GradleDependencies: |} |""".stripMargin - // this init script _should_ work with Gradle 4-8, but has not been tested thoroughly - // TODO: add test cases for older Gradle versions - private def gradle5OrLaterInitScript(taskName: String, destination: String): String = - s""" + // this init script _should_ work with Gradle 4-8, but has not been tested thoroughly + // TODO: add test cases for older Gradle versions + private def gradle5OrLaterInitScript(taskName: String, destination: String): String = + s""" |allprojects { | apply plugin: 'java' | task $taskName(type: Copy) { @@ -102,233 +102,233 @@ object GradleDependencies: |} |""".stripMargin - private def makeInitScript( - destinationDir: Path, - forAndroid: Boolean, - gradleProjectName: String, - gradleConfigurationName: String - ): GradleDepsInitScript = - val taskName = taskNamePrefix + "_" + (Random.alphanumeric.take(8)).toList.mkString - val content = - if forAndroid then - gradle5OrLaterAndroidInitScript( - taskName, - destinationDir.toString, - gradleProjectName, - gradleConfigurationName - ) - else - gradle5OrLaterInitScript(taskName, destinationDir.toString) - GradleDepsInitScript(content, taskName, destinationDir) + private def makeInitScript( + destinationDir: Path, + forAndroid: Boolean, + gradleProjectName: String, + gradleConfigurationName: String + ): GradleDepsInitScript = + val taskName = taskNamePrefix + "_" + (Random.alphanumeric.take(8)).toList.mkString + val content = + if forAndroid then + gradle5OrLaterAndroidInitScript( + taskName, + destinationDir.toString, + gradleProjectName, + gradleConfigurationName + ) + else + gradle5OrLaterInitScript(taskName, destinationDir.toString) + GradleDepsInitScript(content, taskName, destinationDir) - private[dependency] def makeConnection(projectDir: JFile): ProjectConnection = - GradleConnector.newConnector().forProjectDirectory(projectDir).connect() + private[dependency] def makeConnection(projectDir: JFile): ProjectConnection = + GradleConnector.newConnector().forProjectDirectory(projectDir).connect() - private def getGradleProjectInfo( - projectDir: Path, - projectName: String - ): Option[GradleProjectInfo] = - Try(makeConnection(projectDir.toFile)) match - case Success(gradleConnection) => - Using.resource(gradleConnection) { connection => - try - val buildEnv = - connection.getModel[BuildEnvironment](classOf[BuildEnvironment]) - val project = connection.getModel[GradleProject](classOf[GradleProject]) - val hasAndroidPrefixGradleProperty = - runGradleTask(connection, Constants.gradlePropertiesTaskName) match - case Some(out) => - out.split('\n').exists( - _.startsWith(Constants.gradleAndroidPropertyPrefix) - ) - case None => false - val info = GradleProjectInfo( - buildEnv.getGradle.getGradleVersion, - project.getTasks.asScala.map(_.getName).toSeq, - hasAndroidPrefixGradleProperty - ) - if hasAndroidPrefixGradleProperty then - val validProjectNames = List( - project.getName - ) ++ project.getChildren.getAll.asScala.map(_.getName) - logger.debug( - s"Found Gradle projects: ${validProjectNames.mkString(",")}" - ) - if !validProjectNames.contains(projectName) then - val validProjectNamesStr = validProjectNames.mkString(",") - logger.debug( - s"The provided Gradle project name `$projectName` is is not part of the valid project names: `$validProjectNamesStr`" - ) - None - else - Some(info) - else - Some(info) - catch - case t: Throwable => - logger.debug( - s"Caught exception while trying use Gradle connection: ${t.getMessage}" + private def getGradleProjectInfo( + projectDir: Path, + projectName: String + ): Option[GradleProjectInfo] = + Try(makeConnection(projectDir.toFile)) match + case Success(gradleConnection) => + Using.resource(gradleConnection) { connection => + try + val buildEnv = + connection.getModel[BuildEnvironment](classOf[BuildEnvironment]) + val project = connection.getModel[GradleProject](classOf[GradleProject]) + val hasAndroidPrefixGradleProperty = + runGradleTask(connection, Constants.gradlePropertiesTaskName) match + case Some(out) => + out.split('\n').exists( + _.startsWith(Constants.gradleAndroidPropertyPrefix) ) - logger.debug(s"Full exception: ", t) - None - } - case Failure(t) => + case None => false + val info = GradleProjectInfo( + buildEnv.getGradle.getGradleVersion, + project.getTasks.asScala.map(_.getName).toSeq, + hasAndroidPrefixGradleProperty + ) + if hasAndroidPrefixGradleProperty then + val validProjectNames = List( + project.getName + ) ++ project.getChildren.getAll.asScala.map(_.getName) + logger.debug( + s"Found Gradle projects: ${validProjectNames.mkString(",")}" + ) + if !validProjectNames.contains(projectName) then + val validProjectNamesStr = validProjectNames.mkString(",") + logger.debug( + s"The provided Gradle project name `$projectName` is is not part of the valid project names: `$validProjectNamesStr`" + ) + None + else + Some(info) + else + Some(info) + catch + case t: Throwable => + logger.debug( + s"Caught exception while trying use Gradle connection: ${t.getMessage}" + ) + logger.debug(s"Full exception: ", t) + None + } + case Failure(t) => + logger.debug( + s"Caught exception while trying fetch Gradle project information: ${t.getMessage}" + ) + logger.debug(s"Full exception: ", t) + None + + private def runGradleTask(connection: ProjectConnection, taskName: String): Option[String] = + Using.resource(new ByteArrayOutputStream()) { out => + Try( + connection + .newBuild() + .forTasks(taskName) + .setStandardOutput(out) + .run() + ) match + case Success(_) => Some(out.toString) + case Failure(ex) => logger.debug( - s"Caught exception while trying fetch Gradle project information: ${t.getMessage}" + s"Caught exception while executing Gradle task named `$taskName`:", + ex.getMessage ) - logger.debug(s"Full exception: ", t) + logger.debug(s"Full exception: ", ex) None + } - private def runGradleTask(connection: ProjectConnection, taskName: String): Option[String] = - Using.resource(new ByteArrayOutputStream()) { out => - Try( - connection - .newBuild() - .forTasks(taskName) - .setStandardOutput(out) - .run() - ) match - case Success(_) => Some(out.toString) + private def runGradleTask( + connection: ProjectConnection, + initScript: GradleDepsInitScript, + initScriptPath: String + ): Option[collection.Seq[String]] = + Using.resources(new ByteArrayOutputStream, new ByteArrayOutputStream) { + case (stdoutStream, stderrStream) => + logger.debug(s"Executing gradle task '${initScript.taskName}'...") + Try( + connection + .newBuild() + .forTasks(initScript.taskName) + .withArguments("--init-script", initScriptPath) + .setStandardOutput(stdoutStream) + .setStandardError(stderrStream) + .run() + ) match + case Success(_) => + val result = + Files + .list(initScript.destinationDir) + .collect(Collectors.toList[Path]) + .asScala + .map(_.toAbsolutePath.toString) + logger.debug(s"Resolved `${result.size}` dependency files.") + Some(result) case Failure(ex) => logger.debug( - s"Caught exception while executing Gradle task named `$taskName`:", - ex.getMessage + s"Caught exception while executing Gradle task: ${ex.getMessage}" ) - logger.debug(s"Full exception: ", ex) + logger.debug(s"Gradle task execution stdout: \n$stdoutStream") + logger.debug(s"Gradle task execution stderr: \n$stderrStream") None - } + end match + } - private def runGradleTask( - connection: ProjectConnection, - initScript: GradleDepsInitScript, - initScriptPath: String - ): Option[collection.Seq[String]] = - Using.resources(new ByteArrayOutputStream, new ByteArrayOutputStream) { - case (stdoutStream, stderrStream) => - logger.debug(s"Executing gradle task '${initScript.taskName}'...") - Try( - connection - .newBuild() - .forTasks(initScript.taskName) - .withArguments("--init-script", initScriptPath) - .setStandardOutput(stdoutStream) - .setStandardError(stderrStream) - .run() - ) match - case Success(_) => - val result = - Files - .list(initScript.destinationDir) - .collect(Collectors.toList[Path]) - .asScala - .map(_.toAbsolutePath.toString) - logger.debug(s"Resolved `${result.size}` dependency files.") - Some(result) - case Failure(ex) => - logger.debug( - s"Caught exception while executing Gradle task: ${ex.getMessage}" - ) - logger.debug(s"Gradle task execution stdout: \n$stdoutStream") - logger.debug(s"Gradle task execution stderr: \n$stderrStream") - None - end match - } + private def extractClassesJarFromAar(aar: File): Option[Path] = + val newPath = aar.path.toString.replaceFirst(Constants.aarFileExtension + "$", "jar") + val aarUnzipDirSuffix = ".unzipped" + val outDir = File(aar.path.toString + aarUnzipDirSuffix) + aar.unzipTo(outDir, _.getName == Constants.jarInsideAarFileName) + val outFile = File(newPath) + val classesJarEntries = + outDir.listRecursively + .filter(_.path.getFileName.toString == Constants.jarInsideAarFileName) + .toList + if classesJarEntries.size != 1 then + logger.debug(s"Found aar file without `classes.jar` inside at path ${aar.path}") + outDir.delete() + None + else + val classesJar = classesJarEntries.head + logger.trace(s"Copying `classes.jar` for aar at `${aar.path.toString}` into `$newPath`") + classesJar.copyTo(outFile) + outDir.delete() + aar.delete() + Some(outFile.path) + end extractClassesJarFromAar - private def extractClassesJarFromAar(aar: File): Option[Path] = - val newPath = aar.path.toString.replaceFirst(Constants.aarFileExtension + "$", "jar") - val aarUnzipDirSuffix = ".unzipped" - val outDir = File(aar.path.toString + aarUnzipDirSuffix) - aar.unzipTo(outDir, _.getName == Constants.jarInsideAarFileName) - val outFile = File(newPath) - val classesJarEntries = - outDir.listRecursively - .filter(_.path.getFileName.toString == Constants.jarInsideAarFileName) - .toList - if classesJarEntries.size != 1 then - logger.debug(s"Found aar file without `classes.jar` inside at path ${aar.path}") - outDir.delete() - None - else - val classesJar = classesJarEntries.head - logger.trace(s"Copying `classes.jar` for aar at `${aar.path.toString}` into `$newPath`") - classesJar.copyTo(outFile) - outDir.delete() - aar.delete() - Some(outFile.path) - end extractClassesJarFromAar - - // fetch the gradle project information first, then invoke a newly-defined gradle task to copy the necessary jars into - // a destination directory. - private[dependency] def get( - projectDir: Path, - projectName: String, - configurationName: String - ): Option[collection.Seq[String]] = - logger.debug( - s"Fetching Gradle project information at path `$projectDir` with project name `$projectName`." - ) - getGradleProjectInfo(projectDir, projectName) match - case Some(projectInfo) if projectInfo.gradleVersionMajorMinor()._1 < 5 => - logger.debug(s"Unsupported Gradle version `${projectInfo.gradleVersion}`") - None - case Some(projectInfo) => - Try(File.newTemporaryDirectory(tempDirPrefix).deleteOnExit()) match - case Success(destinationDir) => - Try(File.newTemporaryFile(initScriptPrefix).deleteOnExit()) match - case Success(initScriptFile) => - val initScript = - makeInitScript( - destinationDir.path, - projectInfo.hasAndroidSubproject, - projectName, - configurationName - ) - initScriptFile.write(initScript.contents) + // fetch the gradle project information first, then invoke a newly-defined gradle task to copy the necessary jars into + // a destination directory. + private[dependency] def get( + projectDir: Path, + projectName: String, + configurationName: String + ): Option[collection.Seq[String]] = + logger.debug( + s"Fetching Gradle project information at path `$projectDir` with project name `$projectName`." + ) + getGradleProjectInfo(projectDir, projectName) match + case Some(projectInfo) if projectInfo.gradleVersionMajorMinor()._1 < 5 => + logger.debug(s"Unsupported Gradle version `${projectInfo.gradleVersion}`") + None + case Some(projectInfo) => + Try(File.newTemporaryDirectory(tempDirPrefix).deleteOnExit()) match + case Success(destinationDir) => + Try(File.newTemporaryFile(initScriptPrefix).deleteOnExit()) match + case Success(initScriptFile) => + val initScript = + makeInitScript( + destinationDir.path, + projectInfo.hasAndroidSubproject, + projectName, + configurationName + ) + initScriptFile.write(initScript.contents) - logger.debug( - s"Downloading dependencies for configuration `$configurationName` of project `$projectName` at `$projectDir` into `$destinationDir`..." - ) - Try(makeConnection(projectDir.toFile)) match - case Success(connection) => - Using.resource(connection) { c => - runGradleTask( - c, - initScript, - initScriptFile.pathAsString - ) match - case Some(deps) => - Some(deps.map { d => - if !d.endsWith(Constants.aarFileExtension) - then d - else - extractClassesJarFromAar(File(d)) match - case Some(path) => path.toString - case None => d - }) - case None => None - } - case Failure(ex) => - logger.debug( - s"Caught exception while trying to establish a Gradle connection: ${ex.getMessage}" - ) - logger.debug(s"Full exception: ", ex) - None - end match - case Failure(ex) => - logger.debug( - s"Could not create temporary file for Gradle init script: ${ex.getMessage}" - ) - logger.debug(s"Full exception: ", ex) - None - case Failure(ex) => - logger.debug( - s"Could not create temporary directory for saving dependency files: ${ex.getMessage}" - ) - logger.debug("Full exception: ", ex) - None - case None => - logger.debug("Could not fetch Gradle project information") + logger.debug( + s"Downloading dependencies for configuration `$configurationName` of project `$projectName` at `$projectDir` into `$destinationDir`..." + ) + Try(makeConnection(projectDir.toFile)) match + case Success(connection) => + Using.resource(connection) { c => + runGradleTask( + c, + initScript, + initScriptFile.pathAsString + ) match + case Some(deps) => + Some(deps.map { d => + if !d.endsWith(Constants.aarFileExtension) + then d + else + extractClassesJarFromAar(File(d)) match + case Some(path) => path.toString + case None => d + }) + case None => None + } + case Failure(ex) => + logger.debug( + s"Caught exception while trying to establish a Gradle connection: ${ex.getMessage}" + ) + logger.debug(s"Full exception: ", ex) + None + end match + case Failure(ex) => + logger.debug( + s"Could not create temporary file for Gradle init script: ${ex.getMessage}" + ) + logger.debug(s"Full exception: ", ex) + None + case Failure(ex) => + logger.debug( + s"Could not create temporary directory for saving dependency files: ${ex.getMessage}" + ) + logger.debug("Full exception: ", ex) None - end match - end get + case None => + logger.debug("Could not fetch Gradle project information") + None + end match + end get end GradleDependencies diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/dependency/MavenCoordinates.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/dependency/MavenCoordinates.scala index 39f44dbe..7ee389d8 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/dependency/MavenCoordinates.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/dependency/MavenCoordinates.scala @@ -7,8 +7,8 @@ import java.nio.file.Path import scala.util.{Failure, Success} object MavenCoordinates: - private[dependency] def fromGradleOutput(lines: Seq[String]): Seq[String] = - /* + private[dependency] def fromGradleOutput(lines: Seq[String]): Seq[String] = + /* on the following regex, for the following input: ``` | | +--- org.springframework.boot:spring-boot-starter-logging:3.0.5 @@ -39,19 +39,19 @@ object MavenCoordinates: org.slf4j:jul-to-slf4j:2.0.7 ^g1 ------------------^^g2 ^ ``` - */ - val pattern = """^[| ]*[+\\]\s*[-]*\s*([^:]+:[^:]+:)([^\s]+)(\s+->\s+)?([^\s]+)?""".r - lines - .flatMap { l => - pattern.findFirstMatchIn(l) match - case Some(m) => - if Option(m.group(4)).isEmpty then - Some(m.group(1) + m.group(2)) - else - Some(m.group(1) + m.group(4)) - case _ => None - } - .distinct - .sorted - end fromGradleOutput + */ + val pattern = """^[| ]*[+\\]\s*[-]*\s*([^:]+:[^:]+:)([^\s]+)(\s+->\s+)?([^\s]+)?""".r + lines + .flatMap { l => + pattern.findFirstMatchIn(l) match + case Some(m) => + if Option(m.group(4)).isEmpty then + Some(m.group(1) + m.group(2)) + else + Some(m.group(1) + m.group(4)) + case _ => None + } + .distinct + .sorted + end fromGradleOutput end MavenCoordinates diff --git a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/dependency/MavenDependencies.scala b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/dependency/MavenDependencies.scala index 2decb863..a9e974a2 100644 --- a/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/dependency/MavenDependencies.scala +++ b/platform/frontends/x2cpg/src/main/scala/io/appthreat/x2cpg/utils/dependency/MavenDependencies.scala @@ -7,37 +7,37 @@ import java.nio.file.Path import scala.util.{Failure, Success} object MavenDependencies: - private val logger = LoggerFactory.getLogger(getClass) + private val logger = LoggerFactory.getLogger(getClass) - private[dependency] def get(projectDir: Path): Option[collection.Seq[String]] = - // we can't use -Dmdep.outputFile because that keeps overwriting its own output for each sub-project it's running for - val lines = ExternalCommand.run( - s"mvn -B dependency:build-classpath -DincludeScope=compile -Dorg.slf4j.simpleLogger.defaultLogLevel=info -Dorg.slf4j.simpleLogger.logFile=System.out", - projectDir.toString - ) match - case Success(lines) => lines - case Failure(exception) => - logger.debug( - s"Retrieval of compile class path via maven return with error.\n" + - "The compile class path may be missing or partial.\n" + - "Results will suffer from poor type information.\n\n" + - exception.getMessage - ) - // exception message is the program output - and we still want to look for potential partial results - exception.getMessage.linesIterator.toSeq + private[dependency] def get(projectDir: Path): Option[collection.Seq[String]] = + // we can't use -Dmdep.outputFile because that keeps overwriting its own output for each sub-project it's running for + val lines = ExternalCommand.run( + s"mvn -B dependency:build-classpath -DincludeScope=compile -Dorg.slf4j.simpleLogger.defaultLogLevel=info -Dorg.slf4j.simpleLogger.logFile=System.out", + projectDir.toString + ) match + case Success(lines) => lines + case Failure(exception) => + logger.debug( + s"Retrieval of compile class path via maven return with error.\n" + + "The compile class path may be missing or partial.\n" + + "Results will suffer from poor type information.\n\n" + + exception.getMessage + ) + // exception message is the program output - and we still want to look for potential partial results + exception.getMessage.linesIterator.toSeq - var classPathNext = false - val deps = lines - .flatMap { line => - val isClassPathNow = classPathNext - classPathNext = line.endsWith("Dependencies classpath:") + var classPathNext = false + val deps = lines + .flatMap { line => + val isClassPathNow = classPathNext + classPathNext = line.endsWith("Dependencies classpath:") - if isClassPathNow then line.split(':') else Array.empty[String] - } - .distinct - .toList + if isClassPathNow then line.split(':') else Array.empty[String] + } + .distinct + .toList - logger.debug("got {} Maven dependencies", deps.size) - Some(deps) - end get + logger.debug("got {} Maven dependencies", deps.size) + Some(deps) + end get end MavenDependencies diff --git a/platform/src/main/scala/io/appthreat/chencli/ChenExport.scala b/platform/src/main/scala/io/appthreat/chencli/ChenExport.scala index 7c32104f..012f23d6 100644 --- a/platform/src/main/scala/io/appthreat/chencli/ChenExport.scala +++ b/platform/src/main/scala/io/appthreat/chencli/ChenExport.scala @@ -25,220 +25,221 @@ import scala.util.Using object ChenExport: - case class Config( - cpgFileName: String = "cpg.bin", - outDir: String = "out", - repr: Representation.Value = Representation.Cpg14, - format: Format.Value = Format.Dot - ) - - /** Choose from either a subset of the graph, or the entire graph (all). - */ - object Representation extends Enumeration: - val Ast, Cfg, Ddg, Cdg, Pdg, Cpg14, Cpg, All = Value - - lazy val byNameLowercase: Map[String, Value] = - values.map { value => - value.toString.toLowerCase -> value - }.toMap - - def withNameIgnoreCase(s: String): Value = - byNameLowercase.getOrElse( - s, - throw new NoSuchElementException(s"No value found for '$s'") - ) + case class Config( + cpgFileName: String = "cpg.bin", + outDir: String = "out", + repr: Representation.Value = Representation.Cpg14, + format: Format.Value = Format.Dot + ) + + /** Choose from either a subset of the graph, or the entire graph (all). + */ + object Representation extends Enumeration: + val Ast, Cfg, Ddg, Cdg, Pdg, Cpg14, Cpg, All = Value + + lazy val byNameLowercase: Map[String, Value] = + values.map { value => + value.toString.toLowerCase -> value + }.toMap + + def withNameIgnoreCase(s: String): Value = + byNameLowercase.getOrElse( + s, + throw new NoSuchElementException(s"No value found for '$s'") + ) - object Format extends Enumeration: - val Dot, Neo4jCsv, Graphml, Graphson = Value + object Format extends Enumeration: + val Dot, Neo4jCsv, Graphml, Graphson = Value - lazy val byNameLowercase: Map[String, Value] = - values.map { value => - value.toString.toLowerCase -> value - }.toMap + lazy val byNameLowercase: Map[String, Value] = + values.map { value => + value.toString.toLowerCase -> value + }.toMap - def withNameIgnoreCase(s: String): Value = - byNameLowercase.getOrElse( - s, - throw new NoSuchElementException(s"No value found for '$s'") - ) + def withNameIgnoreCase(s: String): Value = + byNameLowercase.getOrElse( + s, + throw new NoSuchElementException(s"No value found for '$s'") + ) - def main(args: Array[String]): Unit = - parseConfig(args).foreach { config => - val outDir = config.outDir - exitIfInvalid(outDir, config.cpgFileName) - mkdir(File(outDir)) + def main(args: Array[String]): Unit = + parseConfig(args).foreach { config => + val outDir = config.outDir + exitIfInvalid(outDir, config.cpgFileName) + mkdir(File(outDir)) - Using.resource(CpgBasedTool.loadFromOdb(config.cpgFileName)) { cpg => - exportCpg(cpg, config.repr, config.format, Paths.get(outDir).toAbsolutePath) - } + Using.resource(CpgBasedTool.loadFromOdb(config.cpgFileName)) { cpg => + exportCpg(cpg, config.repr, config.format, Paths.get(outDir).toAbsolutePath) } + } - private def parseConfig(args: Array[String]): Option[Config] = - new scopt.OptionParser[Config]("chen-export"): - head( - "Dump intermediate graph representations (or entire graph) of code in a given export format" + private def parseConfig(args: Array[String]): Option[Config] = + new scopt.OptionParser[Config]("chen-export"): + head( + "Dump intermediate graph representations (or entire graph) of code in a given export format" + ) + help("help") + arg[String]("cpg") + .text("input CPG file name - defaults to `cpg.bin`") + .optional() + .action((x, c) => c.copy(cpgFileName = x)) + opt[String]('o', "out") + .text("output directory - will be created and must not yet exist") + .action((x, c) => c.copy(outDir = x)) + opt[String]("repr") + .text( + s"representation to extract: [${Representation.values.toSeq.map( + _.toString.toLowerCase + ).sorted.mkString("|")}] - defaults to `${Representation.Cpg14}`" ) - help("help") - arg[String]("cpg") - .text("input CPG file name - defaults to `cpg.bin`") - .optional() - .action((x, c) => c.copy(cpgFileName = x)) - opt[String]('o', "out") - .text("output directory - will be created and must not yet exist") - .action((x, c) => c.copy(outDir = x)) - opt[String]("repr") - .text( - s"representation to extract: [${Representation.values.toSeq.map( - _.toString.toLowerCase - ).sorted.mkString("|")}] - defaults to `${Representation.Cpg14}`" - ) - .action((x, c) => c.copy(repr = Representation.withNameIgnoreCase(x))) - opt[String]("format") - .action((x, c) => c.copy(format = Format.withNameIgnoreCase(x))) - .text( - s"export format, one of [${Format.values.toSeq.map(_.toString.toLowerCase).sorted.mkString("|")}] - defaults to `${Format.Dot}`" - ) - .parse(args, Config()) - - def exportCpg( - cpg: Cpg, - representation: Representation.Value, - format: Format.Value, - outDir: Path - ): Unit = - implicit val semantics: Semantics = DefaultSemantics() - if semantics.elements.isEmpty then - System.err.println("Warning: semantics are empty.") - - CpgBasedTool.addDataFlowOverlayIfNonExistent(cpg) - val context = new LayerCreatorContext(cpg) - - format match - case Format.Dot - if representation == Representation.All || representation == Representation.Cpg => - exportWithOdbFormat(cpg, representation, outDir, DotExporter) - case Format.Dot => - exportDot(representation, outDir, context) - case Format.Neo4jCsv => - exportWithOdbFormat(cpg, representation, outDir, Neo4jCsvExporter) - case Format.Graphml => - exportWithOdbFormat(cpg, representation, outDir, GraphMLExporter) - case Format.Graphson => - exportWithOdbFormat(cpg, representation, outDir, GraphSONExporter) - case other => - throw new NotImplementedError( - s"repr=$representation not yet supported for format=$format" - ) - end exportCpg - - private def exportDot( - repr: Representation.Value, - outDir: Path, - context: LayerCreatorContext - ): Unit = - val outDirStr = outDir.toString - import Representation.* - repr match - case Ast => new DumpAst(AstDumpOptions(outDirStr)).create(context) - case Cfg => new DumpCfg(CfgDumpOptions(outDirStr)).create(context) - case Ddg => new DumpDdg(DdgDumpOptions(outDirStr)).create(context) - case Cdg => new DumpCdg(CdgDumpOptions(outDirStr)).create(context) - case Pdg => new DumpPdg(PdgDumpOptions(outDirStr)).create(context) - case Cpg14 => new DumpCpg14(Cpg14DumpOptions(outDirStr)).create(context) - case other => - throw new NotImplementedError(s"repr=$repr not yet supported for this format") - - private def exportWithOdbFormat( - cpg: Cpg, - repr: Representation.Value, - outDir: Path, - exporter: overflowdb.formats.Exporter - ): Unit = - val ExportResult(nodeCount, edgeCount, _, additionalInfo) = repr match - case Representation.All => - exporter.runExport(cpg.graph, outDir) - case Representation.Cpg => - val windowsFilenameDeduplicationHelper = mutable.Set.empty[String] - splitByMethod(cpg).iterator - .map { case subGraph @ MethodSubGraph(methodName, methodFilename, nodes) => - val relativeFilename = sanitizedFileName( - methodName, - methodFilename, - exporter.defaultFileExtension, - windowsFilenameDeduplicationHelper - ) - val outFileName = outDir.resolve(relativeFilename) - exporter.runExport(nodes, subGraph.edges, outFileName) - } - .reduce(plus) - case other => - throw new NotImplementedError(s"repr=$repr not yet supported for this format") - - println(s"exported $nodeCount nodes, $edgeCount edges into $outDir") - additionalInfo.foreach(println) - end exportWithOdbFormat - - /** for each method in the cpg: recursively traverse all AST edges to get the subgraph of nodes - * within this method add the method and this subgraph to the export add all edges between all - * of these nodes to the export - */ - private def splitByMethod(cpg: Cpg): IterableOnce[MethodSubGraph] = - cpg.method.map { method => - MethodSubGraph( - methodName = method.name, - methodFilename = method.filename, - nodes = method.ast.toSet + .action((x, c) => c.copy(repr = Representation.withNameIgnoreCase(x))) + opt[String]("format") + .action((x, c) => c.copy(format = Format.withNameIgnoreCase(x))) + .text( + s"export format, one of [${Format.values.toSeq.map(_.toString.toLowerCase).sorted + .mkString("|")}] - defaults to `${Format.Dot}`" ) - } - - /** @param windowsFilenameDeduplicationHelper - * utility map to ensure we don't override output files for identical method names - */ - private def sanitizedFileName( - methodName: String, - methodFilename: String, - fileExtension: String, - windowsFilenameDeduplicationHelper: mutable.Set[String] - ): String = - val sanitizedMethodName = methodName.replaceAll("[^a-zA-Z0-9-_\\.]", "_") - val sanitizedFilename = - if scala.util.Properties.isWin then - // windows has some quirks in it's file system, e.g. we need to ensure paths aren't too long - so we're using a - // different strategy to sanitize windows file names: first occurrence of a given method uses the method name - // any methods with the same name afterwards get a `_` suffix - if windowsFilenameDeduplicationHelper.contains(sanitizedMethodName) then - sanitizedFileName( - s"${methodName}_", - methodFilename, - fileExtension, - windowsFilenameDeduplicationHelper - ) - else - windowsFilenameDeduplicationHelper.add(sanitizedMethodName) - sanitizedMethodName - else // non-windows - // handle leading `/` to ensure we're not writing outside of the output directory - val sanitizedPath = - if methodFilename.startsWith("/") then s"_root_/$methodFilename" - else methodFilename - s"$sanitizedPath/$sanitizedMethodName" - - s"$sanitizedFilename.$fileExtension" - end sanitizedFileName - - private def plus(resultA: ExportResult, resultB: ExportResult): ExportResult = - ExportResult( - nodeCount = resultA.nodeCount + resultB.nodeCount, - edgeCount = resultA.edgeCount + resultB.edgeCount, - files = resultA.files ++ resultB.files, - additionalInfo = resultA.additionalInfo - ) - - case class MethodSubGraph(methodName: String, methodFilename: String, nodes: Set[Node]): - def edges: Set[Edge] = - for - node <- nodes - edge <- node.bothE.asScala - if nodes.contains(edge.inNode) && nodes.contains(edge.outNode) - yield edge + .parse(args, Config()) + + def exportCpg( + cpg: Cpg, + representation: Representation.Value, + format: Format.Value, + outDir: Path + ): Unit = + implicit val semantics: Semantics = DefaultSemantics() + if semantics.elements.isEmpty then + System.err.println("Warning: semantics are empty.") + + CpgBasedTool.addDataFlowOverlayIfNonExistent(cpg) + val context = new LayerCreatorContext(cpg) + + format match + case Format.Dot + if representation == Representation.All || representation == Representation.Cpg => + exportWithOdbFormat(cpg, representation, outDir, DotExporter) + case Format.Dot => + exportDot(representation, outDir, context) + case Format.Neo4jCsv => + exportWithOdbFormat(cpg, representation, outDir, Neo4jCsvExporter) + case Format.Graphml => + exportWithOdbFormat(cpg, representation, outDir, GraphMLExporter) + case Format.Graphson => + exportWithOdbFormat(cpg, representation, outDir, GraphSONExporter) + case other => + throw new NotImplementedError( + s"repr=$representation not yet supported for format=$format" + ) + end exportCpg + + private def exportDot( + repr: Representation.Value, + outDir: Path, + context: LayerCreatorContext + ): Unit = + val outDirStr = outDir.toString + import Representation.* + repr match + case Ast => new DumpAst(AstDumpOptions(outDirStr)).create(context) + case Cfg => new DumpCfg(CfgDumpOptions(outDirStr)).create(context) + case Ddg => new DumpDdg(DdgDumpOptions(outDirStr)).create(context) + case Cdg => new DumpCdg(CdgDumpOptions(outDirStr)).create(context) + case Pdg => new DumpPdg(PdgDumpOptions(outDirStr)).create(context) + case Cpg14 => new DumpCpg14(Cpg14DumpOptions(outDirStr)).create(context) + case other => + throw new NotImplementedError(s"repr=$repr not yet supported for this format") + + private def exportWithOdbFormat( + cpg: Cpg, + repr: Representation.Value, + outDir: Path, + exporter: overflowdb.formats.Exporter + ): Unit = + val ExportResult(nodeCount, edgeCount, _, additionalInfo) = repr match + case Representation.All => + exporter.runExport(cpg.graph, outDir) + case Representation.Cpg => + val windowsFilenameDeduplicationHelper = mutable.Set.empty[String] + splitByMethod(cpg).iterator + .map { case subGraph @ MethodSubGraph(methodName, methodFilename, nodes) => + val relativeFilename = sanitizedFileName( + methodName, + methodFilename, + exporter.defaultFileExtension, + windowsFilenameDeduplicationHelper + ) + val outFileName = outDir.resolve(relativeFilename) + exporter.runExport(nodes, subGraph.edges, outFileName) + } + .reduce(plus) + case other => + throw new NotImplementedError(s"repr=$repr not yet supported for this format") + + println(s"exported $nodeCount nodes, $edgeCount edges into $outDir") + additionalInfo.foreach(println) + end exportWithOdbFormat + + /** for each method in the cpg: recursively traverse all AST edges to get the subgraph of nodes + * within this method add the method and this subgraph to the export add all edges between all of + * these nodes to the export + */ + private def splitByMethod(cpg: Cpg): IterableOnce[MethodSubGraph] = + cpg.method.map { method => + MethodSubGraph( + methodName = method.name, + methodFilename = method.filename, + nodes = method.ast.toSet + ) + } + + /** @param windowsFilenameDeduplicationHelper + * utility map to ensure we don't override output files for identical method names + */ + private def sanitizedFileName( + methodName: String, + methodFilename: String, + fileExtension: String, + windowsFilenameDeduplicationHelper: mutable.Set[String] + ): String = + val sanitizedMethodName = methodName.replaceAll("[^a-zA-Z0-9-_\\.]", "_") + val sanitizedFilename = + if scala.util.Properties.isWin then + // windows has some quirks in it's file system, e.g. we need to ensure paths aren't too long - so we're using a + // different strategy to sanitize windows file names: first occurrence of a given method uses the method name + // any methods with the same name afterwards get a `_` suffix + if windowsFilenameDeduplicationHelper.contains(sanitizedMethodName) then + sanitizedFileName( + s"${methodName}_", + methodFilename, + fileExtension, + windowsFilenameDeduplicationHelper + ) + else + windowsFilenameDeduplicationHelper.add(sanitizedMethodName) + sanitizedMethodName + else // non-windows + // handle leading `/` to ensure we're not writing outside of the output directory + val sanitizedPath = + if methodFilename.startsWith("/") then s"_root_/$methodFilename" + else methodFilename + s"$sanitizedPath/$sanitizedMethodName" + + s"$sanitizedFilename.$fileExtension" + end sanitizedFileName + + private def plus(resultA: ExportResult, resultB: ExportResult): ExportResult = + ExportResult( + nodeCount = resultA.nodeCount + resultB.nodeCount, + edgeCount = resultA.edgeCount + resultB.edgeCount, + files = resultA.files ++ resultB.files, + additionalInfo = resultA.additionalInfo + ) + + case class MethodSubGraph(methodName: String, methodFilename: String, nodes: Set[Node]): + def edges: Set[Edge] = + for + node <- nodes + edge <- node.bothE.asScala + if nodes.contains(edge.inNode) && nodes.contains(edge.outNode) + yield edge end ChenExport diff --git a/platform/src/main/scala/io/appthreat/chencli/ChenFlow.scala b/platform/src/main/scala/io/appthreat/chencli/ChenFlow.scala index 2f758fb7..2fe02c8f 100644 --- a/platform/src/main/scala/io/appthreat/chencli/ChenFlow.scala +++ b/platform/src/main/scala/io/appthreat/chencli/ChenFlow.scala @@ -20,88 +20,88 @@ case class FlowConfig( ) object ChenFlow: - def main(args: Array[String]) = - parseConfig(args).foreach { config => - def debugOut(msg: String): Unit = - if config.verbose then - print(msg) - - debugOut("Loading graph... ") - val cpg = CpgBasedTool.loadFromOdb(config.cpgFileName) - debugOut("[DONE]\n") - - implicit val resolver: ICallResolver = NoResolve - val sources = params(cpg, config.srcRegex, config.srcParam) - val sinks = params(cpg, config.dstRegex, config.dstParam) - - debugOut(s"Number of sources: ${sources.size}\n") - debugOut(s"Number of sinks: ${sinks.size}\n") - - implicit val semantics: Semantics = DefaultSemantics() - val engineConfig = EngineConfig(config.depth) - debugOut(s"Analysis depth: ${engineConfig.maxCallDepth}\n") - implicit val context: EngineContext = EngineContext(semantics, engineConfig) - - debugOut("Determining flows...") - sinks.foreach { s => - List(s).reachableByFlows(sources.iterator).p.foreach(println) - } - debugOut("[DONE]") - - debugOut("Closing graph... ") - cpg.close() - debugOut("[DONE]\n") + def main(args: Array[String]) = + parseConfig(args).foreach { config => + def debugOut(msg: String): Unit = + if config.verbose then + print(msg) + + debugOut("Loading graph... ") + val cpg = CpgBasedTool.loadFromOdb(config.cpgFileName) + debugOut("[DONE]\n") + + implicit val resolver: ICallResolver = NoResolve + val sources = params(cpg, config.srcRegex, config.srcParam) + val sinks = params(cpg, config.dstRegex, config.dstParam) + + debugOut(s"Number of sources: ${sources.size}\n") + debugOut(s"Number of sinks: ${sinks.size}\n") + + implicit val semantics: Semantics = DefaultSemantics() + val engineConfig = EngineConfig(config.depth) + debugOut(s"Analysis depth: ${engineConfig.maxCallDepth}\n") + implicit val context: EngineContext = EngineContext(semantics, engineConfig) + + debugOut("Determining flows...") + sinks.foreach { s => + List(s).reachableByFlows(sources.iterator).p.foreach(println) } - - private def parseConfig(args: Array[String]): Option[FlowConfig] = { - new scopt.OptionParser[FlowConfig]("chen-flow"): - head("Find flows") - help("help") - - arg[String]("src") - .text("source regex") - .action((x, c) => c.copy(srcRegex = x)) - - arg[String]("dst") - .text("destination regex") - .action((x, c) => c.copy(dstRegex = x)) - - arg[String]("cpg") - .text("CPG file name ('cpg.bin' by default)") - .optional() - .action((x, c) => c.copy(cpgFileName = x)) - - opt[Int]("src-param") - .text("Source parameter") - .optional() - .action((x, c) => c.copy(dstParam = Some(x))) - - opt[Int]("dst-param") - .text("Destination parameter") - .optional() - .action((x, c) => c.copy(dstParam = Some(x))) - - opt[Int]("depth") - .text("Analysis depth (number of calls to expand)") - .optional() - .action((x, c) => c.copy(depth = x)) - - opt[Unit]("verbose") - .text("Print debug information") - .optional() - .action((_, c) => c.copy(verbose = true)) - }.parse(args, FlowConfig()) - - private def params( - cpg: Cpg, - methodNameRegex: String, - paramIndex: Option[Int] - ): List[MethodParameterIn] = - cpg - .method(methodNameRegex) - .parameter - .filter { p => - paramIndex.isEmpty || paramIndex.contains(p.order) - } - .l + debugOut("[DONE]") + + debugOut("Closing graph... ") + cpg.close() + debugOut("[DONE]\n") + } + + private def parseConfig(args: Array[String]): Option[FlowConfig] = { + new scopt.OptionParser[FlowConfig]("chen-flow"): + head("Find flows") + help("help") + + arg[String]("src") + .text("source regex") + .action((x, c) => c.copy(srcRegex = x)) + + arg[String]("dst") + .text("destination regex") + .action((x, c) => c.copy(dstRegex = x)) + + arg[String]("cpg") + .text("CPG file name ('cpg.bin' by default)") + .optional() + .action((x, c) => c.copy(cpgFileName = x)) + + opt[Int]("src-param") + .text("Source parameter") + .optional() + .action((x, c) => c.copy(dstParam = Some(x))) + + opt[Int]("dst-param") + .text("Destination parameter") + .optional() + .action((x, c) => c.copy(dstParam = Some(x))) + + opt[Int]("depth") + .text("Analysis depth (number of calls to expand)") + .optional() + .action((x, c) => c.copy(depth = x)) + + opt[Unit]("verbose") + .text("Print debug information") + .optional() + .action((_, c) => c.copy(verbose = true)) + }.parse(args, FlowConfig()) + + private def params( + cpg: Cpg, + methodNameRegex: String, + paramIndex: Option[Int] + ): List[MethodParameterIn] = + cpg + .method(methodNameRegex) + .parameter + .filter { p => + paramIndex.isEmpty || paramIndex.contains(p.order) + } + .l end ChenFlow diff --git a/platform/src/main/scala/io/appthreat/chencli/ChenParse.scala b/platform/src/main/scala/io/appthreat/chencli/ChenParse.scala index d6fafe09..9c0f1bb1 100644 --- a/platform/src/main/scala/io/appthreat/chencli/ChenParse.scala +++ b/platform/src/main/scala/io/appthreat/chencli/ChenParse.scala @@ -11,174 +11,174 @@ import scala.jdk.CollectionConverters.* import scala.util.{Failure, Success, Try} object ChenParse: - // Special string used to separate joern-parse opts from frontend-specific opts - val ArgsDelimitor = "--frontend-args" - val DefaultCpgOutFile = "cpg.bin" - var generator: CpgGenerator = scala.compiletime.uninitialized - - def main(args: Array[String]): Unit = - run(args) match - case Success(msg) => - println(msg) - case Failure(err) => - err.printStackTrace() - System.exit(1) - - val optionParser = new scopt.OptionParser[ParserConfig]("chen-parse"): - arg[String]("input") - .optional() - .text("source file or directory containing source files") - .action((x, c) => c.copy(inputPath = x)) - - opt[String]('o', "output") - .text("output filename") - .action((x, c) => c.copy(outputCpgFile = x)) - - opt[String]("language") - .text("source language") - .action((x, c) => c.copy(language = x)) - - opt[Unit]("list-languages") - .text("list available language options") - .action((_, c) => c.copy(listLanguages = true)) - - opt[String]("namespaces") - .text("namespaces to include: comma separated string") - .action((x, c) => c.copy(namespaces = x.split(",").map(_.trim).toSeq)) - - note("Overlay application stage") - - opt[Unit]("nooverlays") - .text("do not apply default overlays") - .action((_, c) => c.copy(enhance = false)) - opt[Unit]("overlaysonly") - .text("Only apply default overlays") - .action((_, c) => c.copy(enhanceOnly = true)) - - opt[Int]("max-num-def") - .text("Maximum number of definitions in per-method data flow calculation") - .action((x, c) => c.copy(maxNumDef = x)) - - note("Misc") - help("help").text("display this help message") - - note( - s"Args specified after the $ArgsDelimitor separator will be passed to the front-end verbatim" - ) - - private def run(args: Array[String]): Try[String] = - val (parserArgs, frontendArgs) = CpgBasedTool.splitArgs(args) - val installConfig = new InstallConfig() - - parseConfig(parserArgs).flatMap { config => - if config.listLanguages then - Try(buildLanguageList()) - else - run(config, frontendArgs, installConfig) - } - - def run( - config: ParserConfig, - frontendArgs: List[String] = List.empty, - installConfig: InstallConfig = InstallConfig() - ): Try[String] = - for - _ <- checkInputPath(config) - language <- getLanguage(config) - _ <- generateCpg(installConfig, frontendArgs, config, language) - _ <- applyDefaultOverlays(config) - yield newCpgCreatedString(config.outputCpgFile) - - private def checkInputPath(config: ParserConfig): Try[Unit] = - Try { - if config.inputPath == "" then - println(optionParser.usage) - throw new AssertionError(s"Input path required") - else if !File(config.inputPath).exists then - throw new AssertionError( - s"Input path does not exist at `${config.inputPath}`, exiting." - ) - else () - } - - private def buildLanguageList(): String = - val s = new mutable.StringBuilder() - s ++= "Available languages (case insensitive):\n" - s ++= Languages.ALL.asScala.map(lang => s"- ${lang.toLowerCase}").mkString("\n") - s.toString() + // Special string used to separate joern-parse opts from frontend-specific opts + val ArgsDelimitor = "--frontend-args" + val DefaultCpgOutFile = "cpg.bin" + var generator: CpgGenerator = scala.compiletime.uninitialized + + def main(args: Array[String]): Unit = + run(args) match + case Success(msg) => + println(msg) + case Failure(err) => + err.printStackTrace() + System.exit(1) + + val optionParser = new scopt.OptionParser[ParserConfig]("chen-parse"): + arg[String]("input") + .optional() + .text("source file or directory containing source files") + .action((x, c) => c.copy(inputPath = x)) + + opt[String]('o', "output") + .text("output filename") + .action((x, c) => c.copy(outputCpgFile = x)) + + opt[String]("language") + .text("source language") + .action((x, c) => c.copy(language = x)) + + opt[Unit]("list-languages") + .text("list available language options") + .action((_, c) => c.copy(listLanguages = true)) + + opt[String]("namespaces") + .text("namespaces to include: comma separated string") + .action((x, c) => c.copy(namespaces = x.split(",").map(_.trim).toSeq)) + + note("Overlay application stage") + + opt[Unit]("nooverlays") + .text("do not apply default overlays") + .action((_, c) => c.copy(enhance = false)) + opt[Unit]("overlaysonly") + .text("Only apply default overlays") + .action((_, c) => c.copy(enhanceOnly = true)) + + opt[Int]("max-num-def") + .text("Maximum number of definitions in per-method data flow calculation") + .action((x, c) => c.copy(maxNumDef = x)) + + note("Misc") + help("help").text("display this help message") + + note( + s"Args specified after the $ArgsDelimitor separator will be passed to the front-end verbatim" + ) - private def getLanguage(config: ParserConfig): Try[String] = - Try { - if config.language.nonEmpty then - config.language - else - guessLanguage(config.inputPath) - .getOrElse( - throw new AssertionError( - s"Could not guess language from input path ${config.inputPath}. Please specify a language using the --language option." - ) - ) - } + private def run(args: Array[String]): Try[String] = + val (parserArgs, frontendArgs) = CpgBasedTool.splitArgs(args) + val installConfig = new InstallConfig() - private def generateCpg( - installConfig: InstallConfig, - frontendArgs: Seq[String], - config: ParserConfig, - language: String - ): Try[String] = - if config.enhanceOnly then - Success("No generation required") + parseConfig(parserArgs).flatMap { config => + if config.listLanguages then + Try(buildLanguageList()) else - println(s"Parsing code at: ${config.inputPath} - language: `$language`") - println("[+] Running language frontend") - Try { - cpgGeneratorForLanguage( - language.toUpperCase, - FrontendConfig(), - installConfig.rootPath.path, - frontendArgs.toList - ).get - }.flatMap { newGenerator => - generator = newGenerator - generator - .generate(config.inputPath, outputPath = config.outputCpgFile) - .recover { case exception => - throw new RuntimeException( - s"Could not generate CPG with language = $language and input = ${config.inputPath}", - exception - ) - } - } - - private def applyDefaultOverlays(config: ParserConfig): Try[String] = - Try { - println("[+] Applying default overlays") - if config.enhance then - val cpg = DefaultOverlays.create(config.outputCpgFile, config.maxNumDef) - generator.applyPostProcessingPasses(cpg) - cpg.close() - "Code property graph generation successful" - } - - case class ParserConfig( - inputPath: String = "", - outputCpgFile: String = DefaultCpgOutFile, - namespaces: Seq[String] = Seq.empty, - enhance: Boolean = true, - enhanceOnly: Boolean = false, - language: String = "", - listLanguages: Boolean = false, - maxNumDef: Int = DefaultOverlays.defaultMaxNumberOfDefinitions - ) - - private def parseConfig(parserArgs: Seq[String]): Try[ParserConfig] = - Try { - optionParser - .parse(parserArgs, ParserConfig()) + run(config, frontendArgs, installConfig) + } + + def run( + config: ParserConfig, + frontendArgs: List[String] = List.empty, + installConfig: InstallConfig = InstallConfig() + ): Try[String] = + for + _ <- checkInputPath(config) + language <- getLanguage(config) + _ <- generateCpg(installConfig, frontendArgs, config, language) + _ <- applyDefaultOverlays(config) + yield newCpgCreatedString(config.outputCpgFile) + + private def checkInputPath(config: ParserConfig): Try[Unit] = + Try { + if config.inputPath == "" then + println(optionParser.usage) + throw new AssertionError(s"Input path required") + else if !File(config.inputPath).exists then + throw new AssertionError( + s"Input path does not exist at `${config.inputPath}`, exiting." + ) + else () + } + + private def buildLanguageList(): String = + val s = new mutable.StringBuilder() + s ++= "Available languages (case insensitive):\n" + s ++= Languages.ALL.asScala.map(lang => s"- ${lang.toLowerCase}").mkString("\n") + s.toString() + + private def getLanguage(config: ParserConfig): Try[String] = + Try { + if config.language.nonEmpty then + config.language + else + guessLanguage(config.inputPath) .getOrElse( - throw new RuntimeException( - s"Error while not parsing command line options: `${parserArgs.mkString(",")}`" + throw new AssertionError( + s"Could not guess language from input path ${config.inputPath}. Please specify a language using the --language option." ) ) + } + + private def generateCpg( + installConfig: InstallConfig, + frontendArgs: Seq[String], + config: ParserConfig, + language: String + ): Try[String] = + if config.enhanceOnly then + Success("No generation required") + else + println(s"Parsing code at: ${config.inputPath} - language: `$language`") + println("[+] Running language frontend") + Try { + cpgGeneratorForLanguage( + language.toUpperCase, + FrontendConfig(), + installConfig.rootPath.path, + frontendArgs.toList + ).get + }.flatMap { newGenerator => + generator = newGenerator + generator + .generate(config.inputPath, outputPath = config.outputCpgFile) + .recover { case exception => + throw new RuntimeException( + s"Could not generate CPG with language = $language and input = ${config.inputPath}", + exception + ) + } } + + private def applyDefaultOverlays(config: ParserConfig): Try[String] = + Try { + println("[+] Applying default overlays") + if config.enhance then + val cpg = DefaultOverlays.create(config.outputCpgFile, config.maxNumDef) + generator.applyPostProcessingPasses(cpg) + cpg.close() + "Code property graph generation successful" + } + + case class ParserConfig( + inputPath: String = "", + outputCpgFile: String = DefaultCpgOutFile, + namespaces: Seq[String] = Seq.empty, + enhance: Boolean = true, + enhanceOnly: Boolean = false, + language: String = "", + listLanguages: Boolean = false, + maxNumDef: Int = DefaultOverlays.defaultMaxNumberOfDefinitions + ) + + private def parseConfig(parserArgs: Seq[String]): Try[ParserConfig] = + Try { + optionParser + .parse(parserArgs, ParserConfig()) + .getOrElse( + throw new RuntimeException( + s"Error while not parsing command line options: `${parserArgs.mkString(",")}`" + ) + ) + } end ChenParse diff --git a/platform/src/main/scala/io/appthreat/chencli/ChenVectors.scala b/platform/src/main/scala/io/appthreat/chencli/ChenVectors.scala index 6f13ea7b..ef8991db 100644 --- a/platform/src/main/scala/io/appthreat/chencli/ChenVectors.scala +++ b/platform/src/main/scala/io/appthreat/chencli/ChenVectors.scala @@ -15,41 +15,41 @@ import scala.util.Using import scala.util.hashing.MurmurHash3 class BagOfPropertiesForNodes extends EmbeddingGenerator[AstNode, (String, String)]: - override def structureToString(pair: (String, String)): String = pair._1 + ":" + pair._2 - override def extractObjects(cpg: Cpg): Iterator[AstNode] = - cpg.graph.V.collect { case x: AstNode => x } - override def enumerateSubStructures(obj: AstNode): List[(String, String)] = - val relevantFieldTypes = - Set(PropertyNames.NAME, PropertyNames.FULL_NAME, PropertyNames.CODE) - val relevantFields = obj - .propertiesMap() - .entrySet() - .asScala - .toList - .filter { e => relevantFieldTypes.contains(e.getKey) } - .sortBy(_.getKey) - .map { e => - (e.getKey, e.getValue.toString) - } - List(("id", obj.id().toString)) ++ relevantFields ++ List(("label", obj.label)) - - override def objectToString(node: AstNode): String = node.id().toString - override def hash(label: String): String = label - - override def vectorToString(vector: Map[(String, String), Double]): String = - val jsonObj = Json.fromFields(vector.keys.map { case (k, v) => (k, Json.fromString(v)) }) - jsonObj.toString + override def structureToString(pair: (String, String)): String = pair._1 + ":" + pair._2 + override def extractObjects(cpg: Cpg): Iterator[AstNode] = + cpg.graph.V.collect { case x: AstNode => x } + override def enumerateSubStructures(obj: AstNode): List[(String, String)] = + val relevantFieldTypes = + Set(PropertyNames.NAME, PropertyNames.FULL_NAME, PropertyNames.CODE) + val relevantFields = obj + .propertiesMap() + .entrySet() + .asScala + .toList + .filter { e => relevantFieldTypes.contains(e.getKey) } + .sortBy(_.getKey) + .map { e => + (e.getKey, e.getValue.toString) + } + List(("id", obj.id().toString)) ++ relevantFields ++ List(("label", obj.label)) + + override def objectToString(node: AstNode): String = node.id().toString + override def hash(label: String): String = label + + override def vectorToString(vector: Map[(String, String), Double]): String = + val jsonObj = Json.fromFields(vector.keys.map { case (k, v) => (k, Json.fromString(v)) }) + jsonObj.toString end BagOfPropertiesForNodes class BagOfAPISymbolsForMethods extends EmbeddingGenerator[Method, AstNode]: - override def extractObjects(cpg: Cpg): Iterator[Method] = cpg.method - override def enumerateSubStructures(method: Method): List[AstNode] = method.ast.l - override def structureToString(astNode: AstNode): String = astNode.code - override def objectToString(method: Method): String = method.fullName + override def extractObjects(cpg: Cpg): Iterator[Method] = cpg.method + override def enumerateSubStructures(method: Method): List[AstNode] = method.ast.l + override def structureToString(astNode: AstNode): String = astNode.code + override def objectToString(method: Method): String = method.fullName object EmbeddingGenerator: - type SparseVectorWithExplicitFeature[S] = Map[String, (Double, S)] - type SparseVector = Map[Int, Double] + type SparseVectorWithExplicitFeature[S] = Map[String, (Double, S)] + type SparseVector = Map[Int, Double] /** Creates an embedding from a code property graph by following three steps: (1) Objects are * extracted from the graph, each of which is ultimately to be mapped to one vector (2) For each @@ -59,121 +59,121 @@ object EmbeddingGenerator: * T: Object type S: Sub structure type */ trait EmbeddingGenerator[T, S]: - import EmbeddingGenerator.* - - case class Embedding(data: () => Iterator[(T, SparseVectorWithExplicitFeature[S])]): - lazy val dimToStructure: Map[String, S] = - val m = mutable.HashMap[String, S]() - data().foreach { case (_, vector) => - vector.foreach { case (hash, (_, structure)) => - m.put(hash, structure) - } - } - m.toMap - - lazy val structureToDim: Map[S, String] = for ((k, v) <- dimToStructure) yield (v, k) - - def objects: Iterator[String] = data().map { case (obj, _) => objectToString(obj) } - - def vectors: Iterator[Map[S, Double]] = data().map { case (_, vector) => - vector.map { case (_, (v, structure)) => structure -> v } - } - - /** Extract a sequence of (object, vector) pairs from a cpg. - */ - def embed(cpg: Cpg): Embedding = - Embedding({ () => - extractObjects(cpg) - .map { obj => - val substructures = enumerateSubStructures(obj) - obj -> vectorize(substructures) - } - }) - - private def vectorize(substructures: List[S]): Map[String, (Double, S)] = - substructures - .groupBy(x => structureToString(x)) - .view - .map { case (_, l) => - val v = l.size - hash(structureToString(l.head)) -> (v.toDouble, l.head) - } - .toMap - - def hash(label: String): String = MurmurHash3.stringHash(label).toString - - def structureToString(s: S): String - - /** A function that creates a sequence of objects from a CPG - */ - def extractObjects(cpg: Cpg): Iterator[T] - - /** A function that, for a given object, extracts its sub structures - */ - def enumerateSubStructures(obj: T): List[S] - - def objectToString(t: T): String - - implicit val formats: DefaultFormats.type = org.json4s.DefaultFormats - - def vectorToString(vector: Map[S, Double]): String = defaultToString(vector) - - def defaultToString[M](v: M): String = Serialization.write(v) + import EmbeddingGenerator.* + + case class Embedding(data: () => Iterator[(T, SparseVectorWithExplicitFeature[S])]): + lazy val dimToStructure: Map[String, S] = + val m = mutable.HashMap[String, S]() + data().foreach { case (_, vector) => + vector.foreach { case (hash, (_, structure)) => + m.put(hash, structure) + } + } + m.toMap + + lazy val structureToDim: Map[S, String] = for ((k, v) <- dimToStructure) yield (v, k) + + def objects: Iterator[String] = data().map { case (obj, _) => objectToString(obj) } + + def vectors: Iterator[Map[S, Double]] = data().map { case (_, vector) => + vector.map { case (_, (v, structure)) => structure -> v } + } + + /** Extract a sequence of (object, vector) pairs from a cpg. + */ + def embed(cpg: Cpg): Embedding = + Embedding({ () => + extractObjects(cpg) + .map { obj => + val substructures = enumerateSubStructures(obj) + obj -> vectorize(substructures) + } + }) + + private def vectorize(substructures: List[S]): Map[String, (Double, S)] = + substructures + .groupBy(x => structureToString(x)) + .view + .map { case (_, l) => + val v = l.size + hash(structureToString(l.head)) -> (v.toDouble, l.head) + } + .toMap + + def hash(label: String): String = MurmurHash3.stringHash(label).toString + + def structureToString(s: S): String + + /** A function that creates a sequence of objects from a CPG + */ + def extractObjects(cpg: Cpg): Iterator[T] + + /** A function that, for a given object, extracts its sub structures + */ + def enumerateSubStructures(obj: T): List[S] + + def objectToString(t: T): String + + implicit val formats: DefaultFormats.type = org.json4s.DefaultFormats + + def vectorToString(vector: Map[S, Double]): String = defaultToString(vector) + + def defaultToString[M](v: M): String = Serialization.write(v) end EmbeddingGenerator object JoernVectors: - implicit val formats: DefaultFormats.type = org.json4s.DefaultFormats - case class Config( - cpgFileName: String = "cpg.bin", - outDir: String = "out", - dimToFeature: Boolean = false - ) - - def main(args: Array[String]) = - parseConfig(args).foreach { config => - exitIfInvalid(config.outDir, config.cpgFileName) - Using.resource(CpgBasedTool.loadFromOdb(config.cpgFileName)) { cpg => - val generator = new BagOfPropertiesForNodes() - val embedding = generator.embed(cpg) - println("{") - println("\"objects\":") - traversalToJson(embedding.objects, generator.defaultToString) - if config.dimToFeature then - println(",\"dimToFeature\": ") - println(Serialization.write(embedding.dimToStructure)) - println(",\"vectors\":") - traversalToJson(embedding.vectors, generator.vectorToString) - println(",\"edges\":") - traversalToJson( - cpg.graph.edges().map { x => - Map("src" -> x.outNode().id(), "dst" -> x.inNode().id(), "label" -> x.label()) - }, - generator.defaultToString - ) - println("}") - } + implicit val formats: DefaultFormats.type = org.json4s.DefaultFormats + case class Config( + cpgFileName: String = "cpg.bin", + outDir: String = "out", + dimToFeature: Boolean = false + ) + + def main(args: Array[String]) = + parseConfig(args).foreach { config => + exitIfInvalid(config.outDir, config.cpgFileName) + Using.resource(CpgBasedTool.loadFromOdb(config.cpgFileName)) { cpg => + val generator = new BagOfPropertiesForNodes() + val embedding = generator.embed(cpg) + println("{") + println("\"objects\":") + traversalToJson(embedding.objects, generator.defaultToString) + if config.dimToFeature then + println(",\"dimToFeature\": ") + println(Serialization.write(embedding.dimToStructure)) + println(",\"vectors\":") + traversalToJson(embedding.vectors, generator.vectorToString) + println(",\"edges\":") + traversalToJson( + cpg.graph.edges().map { x => + Map("src" -> x.outNode().id(), "dst" -> x.inNode().id(), "label" -> x.label()) + }, + generator.defaultToString + ) + println("}") } - - private def parseConfig(args: Array[String]): Option[Config] = - new scopt.OptionParser[Config]("chen-vectors"): - head("Extract vector representations of code from CPG") - help("help") - arg[String]("cpg") - .text("input CPG file name - defaults to `cpg.bin`") - .optional() - .action((x, c) => c.copy(cpgFileName = x)) - opt[String]('o', "out") - .text("output directory - will be created and must not yet exist") - .action((x, c) => c.copy(outDir = x)) - opt[Unit]("features") - .text("Provide map from dimensions to features") - .action((_, c) => c.copy(dimToFeature = true)) - .parse(args, Config()) - - private def traversalToJson[X](trav: Iterator[X], vectorToString: X => String): Unit = - println("[") - trav.nextOption().foreach { vector => print(vectorToString(vector)) } - trav.foreach { vector => print(",\n" + vectorToString(vector)) } - println("]") + } + + private def parseConfig(args: Array[String]): Option[Config] = + new scopt.OptionParser[Config]("chen-vectors"): + head("Extract vector representations of code from CPG") + help("help") + arg[String]("cpg") + .text("input CPG file name - defaults to `cpg.bin`") + .optional() + .action((x, c) => c.copy(cpgFileName = x)) + opt[String]('o', "out") + .text("output directory - will be created and must not yet exist") + .action((x, c) => c.copy(outDir = x)) + opt[Unit]("features") + .text("Provide map from dimensions to features") + .action((_, c) => c.copy(dimToFeature = true)) + .parse(args, Config()) + + private def traversalToJson[X](trav: Iterator[X], vectorToString: X => String): Unit = + println("[") + trav.nextOption().foreach { vector => print(vectorToString(vector)) } + trav.foreach { vector => print(",\n" + vectorToString(vector)) } + println("]") end JoernVectors diff --git a/platform/src/main/scala/io/appthreat/chencli/CpgBasedTool.scala b/platform/src/main/scala/io/appthreat/chencli/CpgBasedTool.scala index 7563937d..b6438078 100644 --- a/platform/src/main/scala/io/appthreat/chencli/CpgBasedTool.scala +++ b/platform/src/main/scala/io/appthreat/chencli/CpgBasedTool.scala @@ -10,51 +10,51 @@ import io.shiftleft.semanticcpg.layers.LayerCreatorContext object CpgBasedTool: - /** Load code property graph from overflowDB - * - * @param filename - * name of the file that stores the CPG - */ - def loadFromOdb(filename: String): Cpg = - val odbConfig = overflowdb.Config.withDefaults().withStorageLocation(filename) - val config = CpgLoaderConfig().withOverflowConfig(odbConfig).doNotCreateIndexesOnLoad - io.shiftleft.codepropertygraph.cpgloading.CpgLoader.loadFromOverflowDb(config) - - /** Add the data flow layer to the CPG if it does not exist yet. - */ - def addDataFlowOverlayIfNonExistent(cpg: Cpg)(implicit s: Semantics): Unit = - if !cpg.metaData.overlays.exists(_ == OssDataFlow.overlayName) then - System.err.println("CPG does not have dataflow overlay. Calculating.") - val opts = new OssDataFlowOptions() - val context = new LayerCreatorContext(cpg) - new OssDataFlow(opts).run(context) - - /** Create an informational string for the user that informs of a successfully generated CPG. - */ - def newCpgCreatedString(path: String): String = - val absolutePath = File(path).path.toAbsolutePath - s"Successfully wrote graph to: $absolutePath\n" + - s"To load the graph, type `joern $absolutePath`" - - val ARGS_DELIMITER = "--frontend-args" - - /** Splits arguments at the ARGS_DELIMITER into arguments for the tool and arguments for the - * language frontend. - */ - def splitArgs(args: Array[String]): (List[String], List[String]) = - args.indexOf(ARGS_DELIMITER) match - case -1 => (args.toList, Nil) - case splitIdx => - val (parseOpts, frontendOpts) = args.toList.splitAt(splitIdx) - (parseOpts, frontendOpts.tail) // Take the tail to ignore the delimiter - - def exitIfInvalid(outDir: String, cpgFileName: String): Unit = - if File(outDir).exists then - exitWithError(s"Output directory `$outDir` already exists.") - if File(cpgFileName).notExists then - exitWithError(s"CPG at $cpgFileName does not exist.") - - def exitWithError(msg: String): Unit = - System.err.println(s"error: $msg") - System.exit(1) + /** Load code property graph from overflowDB + * + * @param filename + * name of the file that stores the CPG + */ + def loadFromOdb(filename: String): Cpg = + val odbConfig = overflowdb.Config.withDefaults().withStorageLocation(filename) + val config = CpgLoaderConfig().withOverflowConfig(odbConfig).doNotCreateIndexesOnLoad + io.shiftleft.codepropertygraph.cpgloading.CpgLoader.loadFromOverflowDb(config) + + /** Add the data flow layer to the CPG if it does not exist yet. + */ + def addDataFlowOverlayIfNonExistent(cpg: Cpg)(implicit s: Semantics): Unit = + if !cpg.metaData.overlays.exists(_ == OssDataFlow.overlayName) then + System.err.println("CPG does not have dataflow overlay. Calculating.") + val opts = new OssDataFlowOptions() + val context = new LayerCreatorContext(cpg) + new OssDataFlow(opts).run(context) + + /** Create an informational string for the user that informs of a successfully generated CPG. + */ + def newCpgCreatedString(path: String): String = + val absolutePath = File(path).path.toAbsolutePath + s"Successfully wrote graph to: $absolutePath\n" + + s"To load the graph, type `joern $absolutePath`" + + val ARGS_DELIMITER = "--frontend-args" + + /** Splits arguments at the ARGS_DELIMITER into arguments for the tool and arguments for the + * language frontend. + */ + def splitArgs(args: Array[String]): (List[String], List[String]) = + args.indexOf(ARGS_DELIMITER) match + case -1 => (args.toList, Nil) + case splitIdx => + val (parseOpts, frontendOpts) = args.toList.splitAt(splitIdx) + (parseOpts, frontendOpts.tail) // Take the tail to ignore the delimiter + + def exitIfInvalid(outDir: String, cpgFileName: String): Unit = + if File(outDir).exists then + exitWithError(s"Output directory `$outDir` already exists.") + if File(cpgFileName).notExists then + exitWithError(s"CPG at $cpgFileName does not exist.") + + def exitWithError(msg: String): Unit = + System.err.println(s"error: $msg") + System.exit(1) end CpgBasedTool diff --git a/platform/src/main/scala/io/appthreat/chencli/DefaultOverlays.scala b/platform/src/main/scala/io/appthreat/chencli/DefaultOverlays.scala index 5aaab1e7..01ab04b9 100644 --- a/platform/src/main/scala/io/appthreat/chencli/DefaultOverlays.scala +++ b/platform/src/main/scala/io/appthreat/chencli/DefaultOverlays.scala @@ -8,22 +8,22 @@ import io.shiftleft.semanticcpg.layers.LayerCreatorContext object DefaultOverlays: - val DEFAULT_CPG_IN_FILE = "cpg.bin" - val defaultMaxNumberOfDefinitions = 4000 + val DEFAULT_CPG_IN_FILE = "cpg.bin" + val defaultMaxNumberOfDefinitions = 4000 - /** Load the CPG at `storeFilename` and add enhancements, turning the CPG into an SCPG. - * - * @param storeFilename - * the filename of the cpg - */ - def create( - storeFilename: String, - maxNumberOfDefinitions: Int = defaultMaxNumberOfDefinitions - ): Cpg = - val cpg = CpgBasedTool.loadFromOdb(storeFilename) - applyDefaultOverlays(cpg) - val context = new LayerCreatorContext(cpg) - val options = new OssDataFlowOptions(maxNumberOfDefinitions) - new OssDataFlow(options).run(context) - cpg + /** Load the CPG at `storeFilename` and add enhancements, turning the CPG into an SCPG. + * + * @param storeFilename + * the filename of the cpg + */ + def create( + storeFilename: String, + maxNumberOfDefinitions: Int = defaultMaxNumberOfDefinitions + ): Cpg = + val cpg = CpgBasedTool.loadFromOdb(storeFilename) + applyDefaultOverlays(cpg) + val context = new LayerCreatorContext(cpg) + val options = new OssDataFlowOptions(maxNumberOfDefinitions) + new OssDataFlow(options).run(context) + cpg end DefaultOverlays diff --git a/platform/src/main/scala/io/appthreat/chencli/console/ChenConsole.scala b/platform/src/main/scala/io/appthreat/chencli/console/ChenConsole.scala index 7382823a..0977df30 100644 --- a/platform/src/main/scala/io/appthreat/chencli/console/ChenConsole.scala +++ b/platform/src/main/scala/io/appthreat/chencli/console/ChenConsole.scala @@ -13,38 +13,38 @@ import java.nio.file.Path object ChenWorkspaceLoader {} class ChenWorkspaceLoader extends WorkspaceLoader[ChenProject]: - override def createProject(projectFile: ProjectFile, path: Path): ChenProject = - val project = new ChenProject(projectFile, path) - project.context = EngineContext() - project + override def createProject(projectFile: ProjectFile, path: Path): ChenProject = + val project = new ChenProject(projectFile, path) + project.context = EngineContext() + project class ChenConsole extends Console[ChenProject](new ChenWorkspaceLoader): - override val config: ConsoleConfig = ChenConsole.defaultConfig + override val config: ConsoleConfig = ChenConsole.defaultConfig - implicit var semantics: Semantics = context.semantics + implicit var semantics: Semantics = context.semantics - // this is set to be `opts.ossdataflow` on initialization of the shell - var ossDataFlowOptions: OssDataFlowOptions = new OssDataFlowOptions() + // this is set to be `opts.ossdataflow` on initialization of the shell + var ossDataFlowOptions: OssDataFlowOptions = new OssDataFlowOptions() - implicit def context: EngineContext = - workspace.getActiveProject - .map(x => x.asInstanceOf[ChenProject].context) - .getOrElse(EngineContext()) + implicit def context: EngineContext = + workspace.getActiveProject + .map(x => x.asInstanceOf[ChenProject].context) + .getOrElse(EngineContext()) - def loadCpg(inputPath: String): Option[Cpg] = - report("Deprecated. Please use `importAtom` instead") - importCpg(inputPath) + def loadCpg(inputPath: String): Option[Cpg] = + report("Deprecated. Please use `importAtom` instead") + importCpg(inputPath) - override def applyDefaultOverlays(cpg: Cpg): Cpg = - super.applyDefaultOverlays(cpg) - _runAnalyzer(new OssDataFlow(ossDataFlowOptions)) + override def applyDefaultOverlays(cpg: Cpg): Cpg = + super.applyDefaultOverlays(cpg) + _runAnalyzer(new OssDataFlow(ossDataFlowOptions)) end ChenConsole object ChenConsole: - def banner(): String = - s""" + def banner(): String = + s""" | _ _ _ _ _ __ |/ |_ _ ._ ._ _. o |_ / \\ / \\ / \\ / |_|_ |\\_ | | (/_ | | | | (_| | |_) \\_/ \\_/ \\_/ / | @@ -52,7 +52,7 @@ object ChenConsole: |Version: $version """.stripMargin - def version: String = - getClass.getPackage.getImplementationVersion + def version: String = + getClass.getPackage.getImplementationVersion - def defaultConfig: ConsoleConfig = new ConsoleConfig() + def defaultConfig: ConsoleConfig = new ConsoleConfig() diff --git a/platform/src/main/scala/io/appthreat/chencli/console/Predefined.scala b/platform/src/main/scala/io/appthreat/chencli/console/Predefined.scala index f8c0ef7d..4d3a81b8 100644 --- a/platform/src/main/scala/io/appthreat/chencli/console/Predefined.scala +++ b/platform/src/main/scala/io/appthreat/chencli/console/Predefined.scala @@ -4,24 +4,24 @@ import io.appthreat.console.{Help, Run} object Predefined: - val shared: Seq[String] = - Seq( - "import _root_.io.appthreat.console._", - "import _root_.io.appthreat.chencli.console.ChenConsole._", - "import _root_.io.appthreat.chencli.console.Chen.context", - "import _root_.io.shiftleft.codepropertygraph.Cpg", - "import _root_.io.shiftleft.codepropertygraph.Cpg.docSearchPackages", - "import _root_.io.shiftleft.codepropertygraph.cpgloading._", - "import _root_.io.shiftleft.codepropertygraph.generated._", - "import _root_.io.shiftleft.codepropertygraph.generated.nodes._", - "import _root_.io.shiftleft.codepropertygraph.generated.edges._", - "import _root_.io.appthreat.dataflowengineoss.language._", - "import _root_.io.shiftleft.semanticcpg.language._", - "import overflowdb._", - "import overflowdb.traversal.{`package` => _, help => _, _}", - "import overflowdb.traversal.help.Doc", - "import scala.jdk.CollectionConverters._", - """ + val shared: Seq[String] = + Seq( + "import _root_.io.appthreat.console.*", + "import _root_.io.appthreat.chencli.console.ChenConsole.*", + "import _root_.io.appthreat.chencli.console.Chen.context", + "import _root_.io.shiftleft.codepropertygraph.Cpg", + "import _root_.io.shiftleft.codepropertygraph.Cpg.docSearchPackages", + "import _root_.io.shiftleft.codepropertygraph.cpgloading.*", + "import _root_.io.shiftleft.codepropertygraph.generated.*", + "import _root_.io.shiftleft.codepropertygraph.generated.nodes.*", + "import _root_.io.shiftleft.codepropertygraph.generated.edges.*", + "import _root_.io.appthreat.dataflowengineoss.language.*", + "import _root_.io.shiftleft.semanticcpg.language.*", + "import overflowdb.*", + "import overflowdb.traversal.{`package` => _, help => _, _}", + "import overflowdb.traversal.help.Doc", + "import scala.jdk.CollectionConverters.*", + """ |@Doc(info = "Show reachable flows from a source to sink. Default source: framework-input and sink: framework-output", example = "reachables") |def reachables(sinkTag: String, sourceTag: String, sourceTags: Array[String])(implicit atom: Cpg): Unit = { | try { @@ -73,12 +73,12 @@ object Predefined: |@Doc(info = "Show reachable flows from a source to sink. Default source: crypto-algorithm and sink: crypto-generate", example = "cryptos") |def cryptos(implicit atom: Cpg): Unit = cryptos("crypto-generate", "crypto-algorithm", Array("api", "framework", "http", "cli-source", "library-call")) |""".stripMargin - ) + ) - val forInteractiveShell: Seq[String] = - shared ++ - Seq("import _root_.io.appthreat.chencli.console.Chen._") ++ - Run.codeForRunCommand().linesIterator ++ - Help.codeForHelpCommand(classOf[ChenConsole]).linesIterator ++ - Seq("ossDataFlowOptions = opts.ossdataflow") + val forInteractiveShell: Seq[String] = + shared ++ + Seq("import _root_.io.appthreat.chencli.console.Chen._") ++ + Run.codeForRunCommand().linesIterator ++ + Help.codeForHelpCommand(classOf[ChenConsole]).linesIterator ++ + Seq("ossDataFlowOptions = opts.ossdataflow") end Predefined diff --git a/platform/src/main/scala/io/appthreat/chencli/console/ReplBridge.scala b/platform/src/main/scala/io/appthreat/chencli/console/ReplBridge.scala index ee980dd5..c46ac8db 100644 --- a/platform/src/main/scala/io/appthreat/chencli/console/ReplBridge.scala +++ b/platform/src/main/scala/io/appthreat/chencli/console/ReplBridge.scala @@ -6,18 +6,18 @@ import java.io.PrintStream object ReplBridge extends BridgeBase: - override val jProduct = ChenProduct + override val jProduct = ChenProduct - def main(args: Array[String]): Unit = - run(parseConfig(args)) + def main(args: Array[String]): Unit = + run(parseConfig(args)) - /** Code that is executed when starting the shell - */ - override def predefLines = - Predefined.forInteractiveShell + /** Code that is executed when starting the shell + */ + override def predefLines = + Predefined.forInteractiveShell - override def greeting = ChenConsole.banner() + override def greeting = ChenConsole.banner() - override def promptStr: String = "chennai" + override def promptStr: String = "chennai" - override def onExitCode: String = "workspace.projects.foreach(_.close)" + override def onExitCode: String = "workspace.projects.foreach(_.close)" diff --git a/project/FileUtils.scala b/project/FileUtils.scala index 13c04a9d..d4039208 100644 --- a/project/FileUtils.scala +++ b/project/FileUtils.scala @@ -1,6 +1,6 @@ import java.io.File import java.nio.file.Files -import scala.collection.JavaConverters._ +import scala.collection.JavaConverters.* object FileUtils { diff --git a/project/Projects.scala b/project/Projects.scala index 5c7da625..1a8d642c 100644 --- a/project/Projects.scala +++ b/project/Projects.scala @@ -1,4 +1,4 @@ -import sbt._ +import sbt.* object Projects { val frontendsRoot = file("platform/frontends") diff --git a/pyproject.toml b/pyproject.toml index f1aa4b44..1ed02a0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "appthreat-chen" -version = "2.1.2" +version = "2.1.3" description = "Code Hierarchy Exploration Net (chen)" authors = ["Team AppThreat "] license = "Apache-2.0" diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/Overlays.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/Overlays.scala index 9e5db2a9..84914a79 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/Overlays.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/Overlays.scala @@ -8,32 +8,32 @@ import overflowdb.BatchedUpdate object Overlays: - def appendOverlayName(cpg: Cpg, overlayName: String): Unit = - new CpgPass(cpg): - override def run(diffGraph: BatchedUpdate.DiffGraphBuilder): Unit = - cpg.metaData.headOption match - case Some(metaData) => - val newValue = metaData.overlays :+ overlayName - diffGraph.setNodeProperty(metaData, Properties.OVERLAYS.name, newValue) - case None => - System.err.println("Missing metaData block") - .createAndApply() + def appendOverlayName(cpg: Cpg, overlayName: String): Unit = + new CpgPass(cpg): + override def run(diffGraph: BatchedUpdate.DiffGraphBuilder): Unit = + cpg.metaData.headOption match + case Some(metaData) => + val newValue = metaData.overlays :+ overlayName + diffGraph.setNodeProperty(metaData, Properties.OVERLAYS.name, newValue) + case None => + System.err.println("Missing metaData block") + .createAndApply() - def removeLastOverlayName(cpg: Cpg): Unit = - new CpgPass(cpg): - override def run(diffGraph: BatchedUpdate.DiffGraphBuilder): Unit = - cpg.metaData.headOption match - case Some(metaData) => - val newValue = metaData.overlays.dropRight(1) - diffGraph.setNodeProperty(metaData, Properties.OVERLAYS.name, newValue) - case None => - System.err.println("Missing metaData block") - .createAndApply() + def removeLastOverlayName(cpg: Cpg): Unit = + new CpgPass(cpg): + override def run(diffGraph: BatchedUpdate.DiffGraphBuilder): Unit = + cpg.metaData.headOption match + case Some(metaData) => + val newValue = metaData.overlays.dropRight(1) + diffGraph.setNodeProperty(metaData, Properties.OVERLAYS.name, newValue) + case None => + System.err.println("Missing metaData block") + .createAndApply() - def appliedOverlays(cpg: Cpg): Seq[String] = - cpg.metaData.headOption match - case Some(metaData) => Option(metaData.overlays).getOrElse(Nil) - case None => - System.err.println("Missing metaData block") - List() + def appliedOverlays(cpg: Cpg): Seq[String] = + cpg.metaData.headOption match + case Some(metaData) => Option(metaData.overlays).getOrElse(Nil) + case None => + System.err.println("Missing metaData block") + List() end Overlays diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/accesspath/AccessElement.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/accesspath/AccessElement.scala index 23df465b..fec4a7cd 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/accesspath/AccessElement.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/accesspath/AccessElement.scala @@ -1,31 +1,31 @@ package io.shiftleft.semanticcpg.accesspath sealed abstract class AccessElement(name: String) extends Comparable[AccessElement]: - override def toString: String = name - def kind: Int - override def hashCode(): Int = kind + name.hashCode + override def toString: String = name + def kind: Int + override def hashCode(): Int = kind + name.hashCode - override def compareTo(other: AccessElement): Int = - this.kind.compareTo(other.kind) match - case 0 => this.name.compareTo(other.toString) - case different => different + override def compareTo(other: AccessElement): Int = + this.kind.compareTo(other.kind) match + case 0 => this.name.compareTo(other.toString) + case different => different case class ConstantAccess(constant: String) extends AccessElement(constant): - override def kind: Int = 0x01010101 + override def kind: Int = 0x01010101 case object VariableAccess extends AccessElement("?"): - override def kind: Int = 0x02020202 + override def kind: Int = 0x02020202 case object VariablePointerShift extends AccessElement(""): - override def kind: Int = 0x03030303 + override def kind: Int = 0x03030303 // this will eventually get an optional extent (how many bytes wide is the memory load/store) case object IndirectionAccess extends AccessElement("*"): - override def kind: Int = 0x04040404 + override def kind: Int = 0x04040404 case object AddressOf extends AccessElement("&"): - override def kind: Int = 0x05050505 + override def kind: Int = 0x05050505 // this will eventually obtain an optional byteOffset case class PointerShift(logicalOffset: Int) extends AccessElement(s"<$logicalOffset>"): - override def kind: Int = 0x06060606 + override def kind: Int = 0x06060606 diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/accesspath/AccessPath.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/accesspath/AccessPath.scala index 1cc87564..9e6b10d1 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/accesspath/AccessPath.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/accesspath/AccessPath.scala @@ -2,232 +2,232 @@ package io.shiftleft.semanticcpg.accesspath object AccessPath: - private val empty = new AccessPath(Elements(), List[Elements]()) - - def apply(): AccessPath = empty - - def apply(elements: Elements, exclusions: Seq[Elements]): AccessPath = - if elements.isEmpty && exclusions.isEmpty then - empty - else - new AccessPath(elements, exclusions.toList) - - def isExtensionExcluded(exclusions: Seq[Elements], extension: Elements): Boolean = - exclusions.exists(e => extension.elements.startsWith(e.elements)) - - /** This class contains operations on `Elements` that are only used in the `AccessPath` class - * and are not part of the public API of `Elements` + private val empty = new AccessPath(Elements(), List[Elements]()) + + def apply(): AccessPath = empty + + def apply(elements: Elements, exclusions: Seq[Elements]): AccessPath = + if elements.isEmpty && exclusions.isEmpty then + empty + else + new AccessPath(elements, exclusions.toList) + + def isExtensionExcluded(exclusions: Seq[Elements], extension: Elements): Boolean = + exclusions.exists(e => extension.elements.startsWith(e.elements)) + + /** This class contains operations on `Elements` that are only used in the `AccessPath` class and + * are not part of the public API of `Elements` + */ + private implicit class ElementsDecorations(el: Elements): + + def noOvertaint(start: Int = 0, untilExclusive: Int = el.elements.length): Boolean = + var idx = start + while idx < untilExclusive do + el.elements(idx) match + case VariablePointerShift | VariableAccess => return false + case _ => + idx += 1 + true + + /** In all sane situations, invertibleTailLength is 0 or 1: + * - we don't expect &, because you cannot take the address of pointer+i (can only take + * address of rvalue) + * - we don't expect & : The & should have collapsed against a preceding *. An example + * where this occurs is (&(ptr->field))[1], which becomes ptr: * field & <2> * This reads + * the _next_ field: It doesn't alias with ptr->field at all, but reads the next bytes in + * the struct, after field. + * + * Such code is very un-idiomatic. */ - private implicit class ElementsDecorations(el: Elements): - - def noOvertaint(start: Int = 0, untilExclusive: Int = el.elements.length): Boolean = - var idx = start - while idx < untilExclusive do - el.elements(idx) match - case VariablePointerShift | VariableAccess => return false - case _ => - idx += 1 - true - - /** In all sane situations, invertibleTailLength is 0 or 1: - * - we don't expect &, because you cannot take the address of pointer+i (can only - * take address of rvalue) - * - we don't expect & : The & should have collapsed against a preceding *. An example - * where this occurs is (&(ptr->field))[1], which becomes ptr: * field & <2> * This - * reads the _next_ field: It doesn't alias with ptr->field at all, but reads the next - * bytes in the struct, after field. - * - * Such code is very un-idiomatic. - */ - def invertibleTailLength: Int = - var i = 0 - val nElements = el.elements.length - 1 - while nElements - i > -1 do - el.elements(nElements - i) match - case AddressOf | VariablePointerShift | _: PointerShift => i += 1 - case _ => return i - i - end ElementsDecorations + def invertibleTailLength: Int = + var i = 0 + val nElements = el.elements.length - 1 + while nElements - i > -1 do + el.elements(nElements - i) match + case AddressOf | VariablePointerShift | _: PointerShift => i += 1 + case _ => return i + i + end ElementsDecorations end AccessPath case class AccessPath(elements: Elements, exclusions: Seq[Elements]): - import AccessPath.* - - def isEmpty: Boolean = this == AccessPath.empty - - private var cachedHash: Int = 0 - - override def hashCode(): Int = - if cachedHash == 0 then - val computedHash = elements.hashCode() + exclusions.hashCode() ^ 0x404f92ab - cachedHash = if computedHash == 0 then 1 else computedHash - cachedHash - else cachedHash - - // for handling of invertible elements, cf AccessPathAlgebra.md - // FIXME: may need to process invertible tail of `other` better - - def ++(other: Elements): Option[AccessPath] = - if isExtensionExcluded(other) then None - else Some(AccessPath(this.elements ++ other, this.truncateExclusions(other).exclusions)) - - // FIXME: may need to process invertible tail of `other` better - def ++(other: AccessPath): Option[AccessPath] = - (this ++ other.elements).map { appended => - other.exclusions.foldLeft(appended) { case (ap, ex) => ap.addExclusion(ex) } - } - - def matchFull(other: AccessPath): FullMatchResult = - val res = this.matchFull(other.elements) - if - res.extensionDiff.isEmpty && res.stepIntoPath.isDefined && other.isExtensionExcluded( - res.stepIntoPath.get.elements - ) - then - FullMatchResult(Some(this), None, Elements.empty) - else res - - def matchFull(other: Elements): FullMatchResult = - val (matchRes, matchDiff) = this.matchAndDiff(other) - matchRes match - case MatchResult.NO_MATCH => - FullMatchResult(Some(this), None, Elements.empty) - case MatchResult.PREFIX_MATCH | MatchResult.EXACT_MATCH => - FullMatchResult(None, Some(AccessPath(matchDiff, this.exclusions)), Elements.empty) - case MatchResult.VARIABLE_PREFIX_MATCH | MatchResult.VARIABLE_EXACT_MATCH => - FullMatchResult( - Some(this), - Some(AccessPath(matchDiff, this.exclusions)), - Elements.empty - ) - case MatchResult.EXTENDED_MATCH => - FullMatchResult( - Some(this.addExclusion(matchDiff)), - Some(AccessPath(Elements.empty, exclusions).truncateExclusions(matchDiff)), - matchDiff - ) - case MatchResult.VARIABLE_EXTENDED_MATCH => - FullMatchResult( - Some(this), - Some(AccessPath(Elements.empty, exclusions).truncateExclusions(matchDiff)), - matchDiff - ) - end match - end matchFull - - def matchAndDiff(other: Elements): (MatchResult.MatchResult, Elements) = - val thisTail = elements.invertibleTailLength - val otherTail = other.invertibleTailLength - val thisHead = elements.elements.length - thisTail - val otherHead = other.elements.length - otherTail - - val cmpUntil = scala.math.min(thisHead, otherHead) - var idx = 0 - var overTainted = false - while idx < cmpUntil do - (elements.elements(idx), other.elements(idx)) match - case (VariableAccess, VariableAccess) | (_: ConstantAccess, VariableAccess) | - (VariableAccess, _: ConstantAccess) | ( - VariablePointerShift, - VariablePointerShift - ) | - (_: PointerShift, VariablePointerShift) | ( - VariablePointerShift, - _: PointerShift - ) => - overTainted = true - case (thisElem, otherElem) => - if thisElem != otherElem then - return (MatchResult.NO_MATCH, Elements.empty) + import AccessPath.* + + def isEmpty: Boolean = this == AccessPath.empty + + private var cachedHash: Int = 0 + + override def hashCode(): Int = + if cachedHash == 0 then + val computedHash = elements.hashCode() + exclusions.hashCode() ^ 0x404f92ab + cachedHash = if computedHash == 0 then 1 else computedHash + cachedHash + else cachedHash + + // for handling of invertible elements, cf AccessPathAlgebra.md + // FIXME: may need to process invertible tail of `other` better + + def ++(other: Elements): Option[AccessPath] = + if isExtensionExcluded(other) then None + else Some(AccessPath(this.elements ++ other, this.truncateExclusions(other).exclusions)) + + // FIXME: may need to process invertible tail of `other` better + def ++(other: AccessPath): Option[AccessPath] = + (this ++ other.elements).map { appended => + other.exclusions.foldLeft(appended) { case (ap, ex) => ap.addExclusion(ex) } + } + + def matchFull(other: AccessPath): FullMatchResult = + val res = this.matchFull(other.elements) + if + res.extensionDiff.isEmpty && res.stepIntoPath.isDefined && other.isExtensionExcluded( + res.stepIntoPath.get.elements + ) + then + FullMatchResult(Some(this), None, Elements.empty) + else res + + def matchFull(other: Elements): FullMatchResult = + val (matchRes, matchDiff) = this.matchAndDiff(other) + matchRes match + case MatchResult.NO_MATCH => + FullMatchResult(Some(this), None, Elements.empty) + case MatchResult.PREFIX_MATCH | MatchResult.EXACT_MATCH => + FullMatchResult(None, Some(AccessPath(matchDiff, this.exclusions)), Elements.empty) + case MatchResult.VARIABLE_PREFIX_MATCH | MatchResult.VARIABLE_EXACT_MATCH => + FullMatchResult( + Some(this), + Some(AccessPath(matchDiff, this.exclusions)), + Elements.empty + ) + case MatchResult.EXTENDED_MATCH => + FullMatchResult( + Some(this.addExclusion(matchDiff)), + Some(AccessPath(Elements.empty, exclusions).truncateExclusions(matchDiff)), + matchDiff + ) + case MatchResult.VARIABLE_EXTENDED_MATCH => + FullMatchResult( + Some(this), + Some(AccessPath(Elements.empty, exclusions).truncateExclusions(matchDiff)), + matchDiff + ) + end match + end matchFull + + def matchAndDiff(other: Elements): (MatchResult.MatchResult, Elements) = + val thisTail = elements.invertibleTailLength + val otherTail = other.invertibleTailLength + val thisHead = elements.elements.length - thisTail + val otherHead = other.elements.length - otherTail + + val cmpUntil = scala.math.min(thisHead, otherHead) + var idx = 0 + var overTainted = false + while idx < cmpUntil do + (elements.elements(idx), other.elements(idx)) match + case (VariableAccess, VariableAccess) | (_: ConstantAccess, VariableAccess) | + (VariableAccess, _: ConstantAccess) | ( + VariablePointerShift, + VariablePointerShift + ) | + (_: PointerShift, VariablePointerShift) | ( + VariablePointerShift, + _: PointerShift + ) => + overTainted = true + case (thisElem, otherElem) => + if thisElem != otherElem then + return (MatchResult.NO_MATCH, Elements.empty) + idx += 1 + var done = false + + /** We now try to greedily match more elements. We know that one of the two paths will only + * contain invertible elements. The issue is the following: prefix <1> & x prefix & With + * greedy matching, we end up with a diff: x. If we just did the invert-append algorithm, we + * would end up with a less precise diff: * <1> & x == * & x. + */ + val minlen = scala.math.min(elements.elements.length, other.elements.length) + while !done && idx < minlen do + (elements.elements(idx), other.elements(idx)) match + case (_: PointerShift, VariablePointerShift) | ( + VariablePointerShift, + _: PointerShift + ) | + (VariablePointerShift, VariablePointerShift) => + overTainted = true idx += 1 - var done = false - - /** We now try to greedily match more elements. We know that one of the two paths will only - * contain invertible elements. The issue is the following: prefix <1> & x prefix & - * With greedy matching, we end up with a diff: x. If we just did the invert-append - * algorithm, we would end up with a less precise diff: * <1> & x == * & x. - */ - val minlen = scala.math.min(elements.elements.length, other.elements.length) - while !done && idx < minlen do - (elements.elements(idx), other.elements(idx)) match - case (_: PointerShift, VariablePointerShift) | ( - VariablePointerShift, - _: PointerShift - ) | - (VariablePointerShift, VariablePointerShift) => - overTainted = true - idx += 1 - case (thisElem, otherElem) => - if thisElem == otherElem then - idx += 1 - else - done = true - if thisHead >= otherHead then - // prefix or exact - val diff = Elements.inverted(other.elements.drop(idx)) ++ Elements.unnormalized( - elements.elements.drop(idx) + case (thisElem, otherElem) => + if thisElem == otherElem then + idx += 1 + else + done = true + if thisHead >= otherHead then + // prefix or exact + val diff = Elements.inverted(other.elements.drop(idx)) ++ Elements.unnormalized( + elements.elements.drop(idx) + ) + + /** we don't need to overtaint if thisTail has variable PointerShift: They can still get + * excluded e.g. suppose we track "a" "b" and encounter "a" <4>. + */ + overTainted |= !other.noOvertaint(otherHead) + if !overTainted & thisHead == otherHead then (MatchResult.EXACT_MATCH, diff) + else if overTainted && thisHead == otherHead then + (MatchResult.VARIABLE_EXACT_MATCH, diff) + else if !overTainted && thisHead != otherHead then (MatchResult.PREFIX_MATCH, diff) + else if overTainted && thisHead != otherHead then + (MatchResult.VARIABLE_PREFIX_MATCH, diff) + else throw new RuntimeException() + else + // extended + val diff = Elements.inverted(elements.elements.drop(idx)) ++ Elements.unnormalized( + other.elements.drop(idx) + ) + + /** we need to overtaint if any either otherTail or thisTail has variable PointerShift: e.g. + * suppose we track "a" <4> and encounter "a" "b" "c", or suppose that we track "a" + * and encounter "a" "b" + */ + overTainted |= !elements.noOvertaint(thisHead) | !other.noOvertaint(otherHead) + + if overTainted then (MatchResult.VARIABLE_EXTENDED_MATCH, diff) + else if isExtensionExcluded(diff) then (MatchResult.NO_MATCH, Elements.empty) + else (MatchResult.EXTENDED_MATCH, diff) + end if + end matchAndDiff + + private def truncateExclusions(compareExclusion: Elements): AccessPath = + if exclusions.isEmpty then return this + val size = compareExclusion.elements.length + val newExclusions = + exclusions + .filter(_.elements.startsWith(compareExclusion.elements)) + .map(exclusion => Elements.normalized(exclusion.elements.drop(size))) + .sorted + AccessPath(elements, newExclusions) + + private def addExclusion(newExclusion: Elements): AccessPath = + if newExclusion.noOvertaint() then + val ex = + Elements.unnormalized( + newExclusion.elements.dropRight(newExclusion.invertibleTailLength) ) + if isExtensionExcluded(ex) then return this + val unshadowed = exclusions.filter(!_.elements.startsWith(ex.elements)) + AccessPath(elements, (unshadowed :+ ex).sorted) + else this - /** we don't need to overtaint if thisTail has variable PointerShift: They can still get - * excluded e.g. suppose we track "a" "b" and encounter "a" <4>. - */ - overTainted |= !other.noOvertaint(otherHead) - if !overTainted & thisHead == otherHead then (MatchResult.EXACT_MATCH, diff) - else if overTainted && thisHead == otherHead then - (MatchResult.VARIABLE_EXACT_MATCH, diff) - else if !overTainted && thisHead != otherHead then (MatchResult.PREFIX_MATCH, diff) - else if overTainted && thisHead != otherHead then - (MatchResult.VARIABLE_PREFIX_MATCH, diff) - else throw new RuntimeException() - else - // extended - val diff = Elements.inverted(elements.elements.drop(idx)) ++ Elements.unnormalized( - other.elements.drop(idx) - ) - - /** we need to overtaint if any either otherTail or thisTail has variable PointerShift: - * e.g. suppose we track "a" <4> and encounter "a" "b" "c", or suppose that we - * track "a" and encounter "a" "b" - */ - overTainted |= !elements.noOvertaint(thisHead) | !other.noOvertaint(otherHead) - - if overTainted then (MatchResult.VARIABLE_EXTENDED_MATCH, diff) - else if isExtensionExcluded(diff) then (MatchResult.NO_MATCH, Elements.empty) - else (MatchResult.EXTENDED_MATCH, diff) - end if - end matchAndDiff - - private def truncateExclusions(compareExclusion: Elements): AccessPath = - if exclusions.isEmpty then return this - val size = compareExclusion.elements.length - val newExclusions = - exclusions - .filter(_.elements.startsWith(compareExclusion.elements)) - .map(exclusion => Elements.normalized(exclusion.elements.drop(size))) - .sorted - AccessPath(elements, newExclusions) - - private def addExclusion(newExclusion: Elements): AccessPath = - if newExclusion.noOvertaint() then - val ex = - Elements.unnormalized( - newExclusion.elements.dropRight(newExclusion.invertibleTailLength) - ) - if isExtensionExcluded(ex) then return this - val unshadowed = exclusions.filter(!_.elements.startsWith(ex.elements)) - AccessPath(elements, (unshadowed :+ ex).sorted) - else this - - def isExtensionExcluded(extension: Elements): Boolean = - AccessPath.isExtensionExcluded(this.exclusions, extension) + def isExtensionExcluded(extension: Elements): Boolean = + AccessPath.isExtensionExcluded(this.exclusions, extension) end AccessPath sealed trait MatchResult object MatchResult extends Enumeration: - type MatchResult = Value - val NO_MATCH, EXACT_MATCH, VARIABLE_EXACT_MATCH, PREFIX_MATCH, VARIABLE_PREFIX_MATCH, - EXTENDED_MATCH, VARIABLE_EXTENDED_MATCH = Value + type MatchResult = Value + val NO_MATCH, EXACT_MATCH, VARIABLE_EXACT_MATCH, PREFIX_MATCH, VARIABLE_PREFIX_MATCH, + EXTENDED_MATCH, VARIABLE_EXTENDED_MATCH = Value /** Result of `matchFull` comparison * @@ -257,7 +257,7 @@ case class FullMatchResult( stepIntoPath: Option[AccessPath], extensionDiff: Elements ): - def hasMatch: Boolean = stepIntoPath.nonEmpty + def hasMatch: Boolean = stepIntoPath.nonEmpty /** For handling of invertible elements, cf AccessPathAlgebra.md. The general rule is that elements * concatenate normally, except for: @@ -292,160 +292,160 @@ case class FullMatchResult( // Elements.empty object Elements: - val empty = new Elements() - - def apply(): Elements = empty - - def normalized(elems: IterableOnce[AccessElement]): Elements = - destructiveNormalized(elems.iterator.toArray) - - def normalized(elems: AccessElement*): Elements = - destructiveNormalized(elems.toArray) - - def unnormalized(elems: IterableOnce[AccessElement]): Elements = - newIfNonEmpty(elems.iterator.toArray) - - def newIfNonEmpty(elems: Array[AccessElement]): Elements = - if !elems.isEmpty then new Elements(elems) - else empty - - def inverted(elems: Iterable[AccessElement]): Elements = - val invertedElems: Array[AccessElement] = elems.toArray.reverse.map { - case AddressOf => IndirectionAccess - case IndirectionAccess => AddressOf - case PointerShift(idx) => PointerShift(-idx) - case VariablePointerShift => VariablePointerShift - case _ => throw new RuntimeException(s"Cannot invert ${Elements.unnormalized(elems)}") - } - newIfNonEmpty(invertedElems) - - def noOvertaint(elems: Iterable[AccessElement]): Boolean = - elems.forall(_ != VariableAccess) - - private def destructiveNormalized(elems: Array[AccessElement]): Elements = - var idxRight = 0 - var idxLeft = -1 - while idxRight < elems.length do - val nextE = elems(idxRight) - nextE match - case shift: PointerShift if shift.logicalOffset == 0 => - // nothing to do + val empty = new Elements() + + def apply(): Elements = empty + + def normalized(elems: IterableOnce[AccessElement]): Elements = + destructiveNormalized(elems.iterator.toArray) + + def normalized(elems: AccessElement*): Elements = + destructiveNormalized(elems.toArray) + + def unnormalized(elems: IterableOnce[AccessElement]): Elements = + newIfNonEmpty(elems.iterator.toArray) + + def newIfNonEmpty(elems: Array[AccessElement]): Elements = + if !elems.isEmpty then new Elements(elems) + else empty + + def inverted(elems: Iterable[AccessElement]): Elements = + val invertedElems: Array[AccessElement] = elems.toArray.reverse.map { + case AddressOf => IndirectionAccess + case IndirectionAccess => AddressOf + case PointerShift(idx) => PointerShift(-idx) + case VariablePointerShift => VariablePointerShift + case _ => throw new RuntimeException(s"Cannot invert ${Elements.unnormalized(elems)}") + } + newIfNonEmpty(invertedElems) + + def noOvertaint(elems: Iterable[AccessElement]): Boolean = + elems.forall(_ != VariableAccess) + + private def destructiveNormalized(elems: Array[AccessElement]): Elements = + var idxRight = 0 + var idxLeft = -1 + while idxRight < elems.length do + val nextE = elems(idxRight) + nextE match + case shift: PointerShift if shift.logicalOffset == 0 => + // nothing to do + case _ => + if idxLeft == -1 then + idxLeft = 0 + elems(0) = nextE + else + val lastE = elems(idxLeft) + (lastE, nextE) match + case (last: PointerShift, next: PointerShift) => + val newShift = last.logicalOffset + next.logicalOffset + if newShift != 0 then elems(idxLeft) = PointerShift(newShift) + else idxLeft -= 1 + case (VariablePointerShift, _: PointerShift) | ( + VariablePointerShift, + VariablePointerShift + ) => + case (_: PointerShift, VariablePointerShift) => + elems(idxLeft) = VariablePointerShift + case (AddressOf, IndirectionAccess) => + idxLeft -= 1 + case (IndirectionAccess, AddressOf) => + idxLeft -= 1 // WRONG but useful, cf comment for `Elements.:+` case _ => - if idxLeft == -1 then - idxLeft = 0 - elems(0) = nextE - else - val lastE = elems(idxLeft) - (lastE, nextE) match - case (last: PointerShift, next: PointerShift) => - val newShift = last.logicalOffset + next.logicalOffset - if newShift != 0 then elems(idxLeft) = PointerShift(newShift) - else idxLeft -= 1 - case (VariablePointerShift, _: PointerShift) | ( - VariablePointerShift, - VariablePointerShift - ) => - case (_: PointerShift, VariablePointerShift) => - elems(idxLeft) = VariablePointerShift - case (AddressOf, IndirectionAccess) => - idxLeft -= 1 - case (IndirectionAccess, AddressOf) => - idxLeft -= 1 // WRONG but useful, cf comment for `Elements.:+` - case _ => - idxLeft += 1 - elems(idxLeft) = nextE - end match - idxRight += 1 - end while - newIfNonEmpty(elems.take(idxLeft + 1)) - end destructiveNormalized + idxLeft += 1 + elems(idxLeft) = nextE + end match + idxRight += 1 + end while + newIfNonEmpty(elems.take(idxLeft + 1)) + end destructiveNormalized end Elements final class Elements(val elements: Array[AccessElement] = Array[AccessElement]()) extends Comparable[Elements]: - def isEmpty: Boolean = elements.isEmpty - - override def toString: String = s"Elements(${elements.mkString(",")})" - - override def equals(other: Any): Boolean = - other match - case otherElements: Elements => - Array.equals( - elements.asInstanceOf[Array[AnyRef]], - otherElements.elements.asInstanceOf[Array[AnyRef]] - ) - case _ => false - - override def hashCode(): Int = java.util.Arrays.hashCode(elements.asInstanceOf[Array[AnyRef]]) - - override def compareTo(other: Elements): Int = - val until = scala.math.min(elements.length, other.elements.length) - var idx = 0 - while idx < until do - elements(idx).compareTo(other.elements(idx)) match - case 0 => - case difference => return difference - idx += 1 - if idx < elements.length then +1 - else if idx < other.elements.length then -1 - else 0 - - def ++(otherElements: Elements): Elements = + def isEmpty: Boolean = elements.isEmpty - if elements.isEmpty then return otherElements - if otherElements.isEmpty then return this + override def toString: String = s"Elements(${elements.mkString(",")})" - var buf = None: Option[AccessElement] - val otherSize = otherElements.elements.length - var idx = 0 - val until = scala.math.min(elements.length, otherSize) - var done = false - - while idx < until & !done do - (elements(elements.length - idx - 1), otherElements.elements(idx)) match - case (AddressOf, IndirectionAccess) => - idx += 1 - case (IndirectionAccess, AddressOf) => - // WRONG but useful, cf comment for `Elements.:+` - idx += 1 - case (VariablePointerShift, VariablePointerShift) | ( - _: PointerShift, - VariablePointerShift - ) | - (VariablePointerShift, _: PointerShift) => - done = true - buf = Some(VariablePointerShift) - idx += 1 - case (last: PointerShift, first: PointerShift) => - val newOffset = last.logicalOffset + first.logicalOffset - if newOffset != 0 then - done = true - buf = Some(PointerShift(newOffset)) - idx += 1 - case _ => - done = true - end while - val sz = elements.length + otherSize - 2 * idx + (if buf.isDefined then 1 else 0) - val res = Array.fill(sz) { null }: Array[AccessElement] - elements.copyToArray(res, 0, elements.length - idx) - if buf.isDefined then - res(elements.length - idx) = buf.get - java.lang.System.arraycopy( - otherElements.elements, - idx, - res, - elements.length - idx + 1, - otherSize - idx + override def equals(other: Any): Boolean = + other match + case otherElements: Elements => + Array.equals( + elements.asInstanceOf[Array[AnyRef]], + otherElements.elements.asInstanceOf[Array[AnyRef]] ) - else - java.lang.System.arraycopy( - otherElements.elements, - idx, - res, - elements.length - idx, - otherSize - idx - ) - Elements.newIfNonEmpty(res) - end ++ + case _ => false + + override def hashCode(): Int = java.util.Arrays.hashCode(elements.asInstanceOf[Array[AnyRef]]) + + override def compareTo(other: Elements): Int = + val until = scala.math.min(elements.length, other.elements.length) + var idx = 0 + while idx < until do + elements(idx).compareTo(other.elements(idx)) match + case 0 => + case difference => return difference + idx += 1 + if idx < elements.length then +1 + else if idx < other.elements.length then -1 + else 0 + + def ++(otherElements: Elements): Elements = + + if elements.isEmpty then return otherElements + if otherElements.isEmpty then return this + + var buf = None: Option[AccessElement] + val otherSize = otherElements.elements.length + var idx = 0 + val until = scala.math.min(elements.length, otherSize) + var done = false + + while idx < until & !done do + (elements(elements.length - idx - 1), otherElements.elements(idx)) match + case (AddressOf, IndirectionAccess) => + idx += 1 + case (IndirectionAccess, AddressOf) => + // WRONG but useful, cf comment for `Elements.:+` + idx += 1 + case (VariablePointerShift, VariablePointerShift) | ( + _: PointerShift, + VariablePointerShift + ) | + (VariablePointerShift, _: PointerShift) => + done = true + buf = Some(VariablePointerShift) + idx += 1 + case (last: PointerShift, first: PointerShift) => + val newOffset = last.logicalOffset + first.logicalOffset + if newOffset != 0 then + done = true + buf = Some(PointerShift(newOffset)) + idx += 1 + case _ => + done = true + end while + val sz = elements.length + otherSize - 2 * idx + (if buf.isDefined then 1 else 0) + val res = Array.fill(sz) { null }: Array[AccessElement] + elements.copyToArray(res, 0, elements.length - idx) + if buf.isDefined then + res(elements.length - idx) = buf.get + java.lang.System.arraycopy( + otherElements.elements, + idx, + res, + elements.length - idx + 1, + otherSize - idx + ) + else + java.lang.System.arraycopy( + otherElements.elements, + idx, + res, + elements.length - idx, + otherSize - idx + ) + Elements.newIfNonEmpty(res) + end ++ end Elements diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/accesspath/TrackedBase.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/accesspath/TrackedBase.scala index 6a66cd22..56e9a1c1 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/accesspath/TrackedBase.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/accesspath/TrackedBase.scala @@ -5,30 +5,30 @@ import io.shiftleft.codepropertygraph.generated.nodes.* trait TrackedBase case class TrackedNamedVariable(name: String) extends TrackedBase case class TrackedReturnValue(call: CallRepr) extends TrackedBase: - override def toString: String = - s"TrackedReturnValue(${call.code})" + override def toString: String = + s"TrackedReturnValue(${call.code})" case class TrackedLiteral(literal: Literal) extends TrackedBase: - override def toString: String = - s"TrackedLiteral(${literal.code})" + override def toString: String = + s"TrackedLiteral(${literal.code})" sealed trait TrackedMethodOrTypeRef extends TrackedBase: - def code: String + def code: String - override def toString: String = - s"TrackedMethodOrTypeRef($code)" + override def toString: String = + s"TrackedMethodOrTypeRef($code)" case class TrackedMethod(method: MethodRef) extends TrackedMethodOrTypeRef: - override def code: String = method.code + override def code: String = method.code case class TrackedTypeRef(typeRef: TypeRef) extends TrackedMethodOrTypeRef: - override def code: String = typeRef.code + override def code: String = typeRef.code case class TrackedAlias(argIndex: Int) extends TrackedBase: - override def toString: String = - s"TrackedAlias($argIndex)" + override def toString: String = + s"TrackedAlias($argIndex)" object TrackedUnknown extends TrackedBase: - override def toString: String = - "TrackedUnknown" + override def toString: String = + "TrackedUnknown" object TrackedFormalReturn extends TrackedBase: - override def toString: String = - "TrackedFormalReturn" + override def toString: String = + "TrackedFormalReturn" diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/codedumper/CodeDumper.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/codedumper/CodeDumper.scala index 2ec9ebf9..290c5582 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/codedumper/CodeDumper.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/codedumper/CodeDumper.scala @@ -11,112 +11,112 @@ import scala.util.{Failure, Success, Try} object CodeDumper: - private val logger: Logger = LoggerFactory.getLogger(getClass) + private val logger: Logger = LoggerFactory.getLogger(getClass) - def arrow(locationFullName: Option[String] = None): CharSequence = - s"/* <=== ${locationFullName.getOrElse("")} */ " + def arrow(locationFullName: Option[String] = None): CharSequence = + s"/* <=== ${locationFullName.getOrElse("")} */ " - private val supportedLanguages = - Set( - Languages.C, - Languages.NEWC, - Languages.GHIDRA, - Languages.JAVA, - Languages.JAVASRC, - Languages.JSSRC, - Languages.PYTHON, - Languages.PYTHONSRC - ) + private val supportedLanguages = + Set( + Languages.C, + Languages.NEWC, + Languages.GHIDRA, + Languages.JAVA, + Languages.JAVASRC, + Languages.JSSRC, + Languages.PYTHON, + Languages.PYTHONSRC + ) - private def toAbsolutePath(path: String, rootPath: String): String = - val absolutePath = Paths.get(path) match - case p if p.isAbsolute => p - case _ if rootPath.endsWith(path) => Paths.get(rootPath) - case p => Paths.get(rootPath, p.toString) - absolutePath.normalize().toString + private def toAbsolutePath(path: String, rootPath: String): String = + val absolutePath = Paths.get(path) match + case p if p.isAbsolute => p + case _ if rootPath.endsWith(path) => Paths.get(rootPath) + case p => Paths.get(rootPath, p.toString) + absolutePath.normalize().toString - /** Dump string representation of code at given `location`. - */ - def dump( - location: NewLocation, - language: Option[String], - rootPath: Option[String], - highlight: Boolean, - withArrow: Boolean = true - ): String = - (location.node, language) match - case (None, _) => - logger.warn("Empty `location.node` encountered") - "" - case (_, None) => - logger.debug("dump not supported; language not set in CPG") - "" - case (_, Some(lang)) if !supportedLanguages.contains(lang) => - logger.debug(s"dump not supported for language '$lang'") - "" - case (Some(node), Some(lang)) => - val method: Option[Method] = node match - case n: Method => Some(n) - case n: Expression => Some(n.method) - case n: Local => n.method.headOption - case _ => None - method - .collect { - case m: Method if m.lineNumber.isDefined && m.lineNumberEnd.isDefined => - val rawCode = if lang == Languages.GHIDRA || lang == Languages.JAVA then - val lines = m.code.split("\n") - lines.zipWithIndex - .map { case (line, lineNo) => - if lineNo == 0 && withArrow then - s"$line ${arrow(Option(m.fullName))}" - else - line - } - .mkString("\n") - else - val filename = rootPath.map( - toAbsolutePath(location.filename, _) - ).getOrElse(location.filename) - code( - filename, - m.lineNumber.get, - m.lineNumberEnd.get, - location.lineNumber, - Option(m.fullName) - ) - if highlight then - SourceHighlighter.highlight(Source(rawCode, lang)) - else - Some(rawCode) - } - .flatten - .getOrElse("") - - /** For a given `filename`, `startLine`, and `endLine`, return the corresponding code by reading - * it from the file. If `lineToHighlight` is defined, then a line containing an arrow (as a - * source code comment) is included right before that line. - */ - def code( - filename: String, - startLine: Integer, - endLine: Integer, - lineToHighlight: Option[Integer] = None, - locationFullName: Option[String] = None - ): String = - Try(IOUtils.readLinesInFile(Paths.get(filename))) match - case Failure(exception) => - logger.warn(s"error reading from: '$filename'", exception) - "" - case Success(lines) => - lines - .slice(startLine - 1, endLine) - .zipWithIndex - .map { case (line, lineNo) => - if lineToHighlight.isDefined && lineNo == lineToHighlight.get - startLine - then - s"$line ${arrow(locationFullName)}" + /** Dump string representation of code at given `location`. + */ + def dump( + location: NewLocation, + language: Option[String], + rootPath: Option[String], + highlight: Boolean, + withArrow: Boolean = true + ): String = + (location.node, language) match + case (None, _) => + logger.warn("Empty `location.node` encountered") + "" + case (_, None) => + logger.debug("dump not supported; language not set in CPG") + "" + case (_, Some(lang)) if !supportedLanguages.contains(lang) => + logger.debug(s"dump not supported for language '$lang'") + "" + case (Some(node), Some(lang)) => + val method: Option[Method] = node match + case n: Method => Some(n) + case n: Expression => Some(n.method) + case n: Local => n.method.headOption + case _ => None + method + .collect { + case m: Method if m.lineNumber.isDefined && m.lineNumberEnd.isDefined => + val rawCode = if lang == Languages.GHIDRA || lang == Languages.JAVA then + val lines = m.code.split("\n") + lines.zipWithIndex + .map { case (line, lineNo) => + if lineNo == 0 && withArrow then + s"$line ${arrow(Option(m.fullName))}" + else + line + } + .mkString("\n") + else + val filename = rootPath.map( + toAbsolutePath(location.filename, _) + ).getOrElse(location.filename) + code( + filename, + m.lineNumber.get, + m.lineNumberEnd.get, + location.lineNumber, + Option(m.fullName) + ) + if highlight then + SourceHighlighter.highlight(Source(rawCode, lang)) else - line - } - .mkString("\n") + Some(rawCode) + } + .flatten + .getOrElse("") + + /** For a given `filename`, `startLine`, and `endLine`, return the corresponding code by reading + * it from the file. If `lineToHighlight` is defined, then a line containing an arrow (as a + * source code comment) is included right before that line. + */ + def code( + filename: String, + startLine: Integer, + endLine: Integer, + lineToHighlight: Option[Integer] = None, + locationFullName: Option[String] = None + ): String = + Try(IOUtils.readLinesInFile(Paths.get(filename))) match + case Failure(exception) => + logger.warn(s"error reading from: '$filename'", exception) + "" + case Success(lines) => + lines + .slice(startLine - 1, endLine) + .zipWithIndex + .map { case (line, lineNo) => + if lineToHighlight.isDefined && lineNo == lineToHighlight.get - startLine + then + s"$line ${arrow(locationFullName)}" + else + line + } + .mkString("\n") end CodeDumper diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/codedumper/SourceHighlighter.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/codedumper/SourceHighlighter.scala index 15bdc869..e2a38bcf 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/codedumper/SourceHighlighter.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/codedumper/SourceHighlighter.scala @@ -12,32 +12,32 @@ import scala.sys.process.Process case class Source(code: String, language: String) object SourceHighlighter: - private val logger: Logger = LoggerFactory.getLogger(SourceHighlighter.getClass) + private val logger: Logger = LoggerFactory.getLogger(SourceHighlighter.getClass) - def highlight(source: Source): Option[String] = - val langFlag = source.language match - case Languages.C | Languages.NEWC | Languages.GHIDRA => "-sC" - case Languages.JAVA | Languages.JAVASRC => "-sJava" - case Languages.JSSRC | Languages.JAVASCRIPT => "-sJavascript" - case Languages.PYTHON | Languages.PYTHONSRC => "-sPython" - case other => throw new RuntimeException( - s"Attempting to call highlighter on unsupported language: $other" - ) + def highlight(source: Source): Option[String] = + val langFlag = source.language match + case Languages.C | Languages.NEWC | Languages.GHIDRA => "-sC" + case Languages.JAVA | Languages.JAVASRC => "-sJava" + case Languages.JSSRC | Languages.JAVASCRIPT => "-sJavascript" + case Languages.PYTHON | Languages.PYTHONSRC => "-sPython" + case other => throw new RuntimeException( + s"Attempting to call highlighter on unsupported language: $other" + ) - val tmpSrcFile = File.newTemporaryFile("dump") - tmpSrcFile.writeText(source.code) - try - val highlightedCode = - Process(Seq("source-highlight-esc.sh", tmpSrcFile.path.toString, langFlag)).!! - Some(highlightedCode) - catch - case exception: Exception => - logger.debug( - "syntax highlighting not working. Is `source-highlight` installed?", - exception - ) - Some(source.code) - finally - tmpSrcFile.delete() - end highlight + val tmpSrcFile = File.newTemporaryFile("dump") + tmpSrcFile.writeText(source.code) + try + val highlightedCode = + Process(Seq("source-highlight-esc.sh", tmpSrcFile.path.toString, langFlag)).!! + Some(highlightedCode) + catch + case exception: Exception => + logger.debug( + "syntax highlighting not working. Is `source-highlight` installed?", + exception + ) + Some(source.code) + finally + tmpSrcFile.delete() + end highlight end SourceHighlighter diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/AstGenerator.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/AstGenerator.scala index 9aeef7c0..f485d135 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/AstGenerator.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/AstGenerator.scala @@ -7,14 +7,14 @@ import io.shiftleft.semanticcpg.language.* class AstGenerator: - private val edgeType = EdgeTypes.AST + private val edgeType = EdgeTypes.AST - def generate(astRoot: AstNode): Graph = - def shouldBeDisplayed(v: AstNode): Boolean = !v.isInstanceOf[MethodParameterOut] - val vertices = astRoot.ast.filter(shouldBeDisplayed).l - val edges = vertices.flatMap(v => - v.astChildren.filter(shouldBeDisplayed).map { child => - Edge(v, child, edgeType = edgeType) - } - ) - Graph(vertices, edges) + def generate(astRoot: AstNode): Graph = + def shouldBeDisplayed(v: AstNode): Boolean = !v.isInstanceOf[MethodParameterOut] + val vertices = astRoot.ast.filter(shouldBeDisplayed).l + val edges = vertices.flatMap(v => + v.astChildren.filter(shouldBeDisplayed).map { child => + Edge(v, child, edgeType = edgeType) + } + ) + Graph(vertices, edges) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CallGraphGenerator.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CallGraphGenerator.scala index 63b366a9..c94a85b4 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CallGraphGenerator.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CallGraphGenerator.scala @@ -9,30 +9,30 @@ import scala.collection.mutable class CallGraphGenerator: - def generate(cpg: Cpg): Graph = - val subgraph = mutable.HashMap.empty[String, Seq[StoredNode]] - val vertices = cpg.method.l - val edges = - for - srcMethod <- vertices - _ = storeInSubgraph(srcMethod, subgraph) - child <- srcMethod.call - tgt <- child.callOut - yield - storeInSubgraph(tgt, subgraph) - Edge(srcMethod, tgt, label = child.dispatchType.stripSuffix("_DISPATCH")) - Graph(vertices, edges.distinct, subgraph.toMap) + def generate(cpg: Cpg): Graph = + val subgraph = mutable.HashMap.empty[String, Seq[StoredNode]] + val vertices = cpg.method.l + val edges = + for + srcMethod <- vertices + _ = storeInSubgraph(srcMethod, subgraph) + child <- srcMethod.call + tgt <- child.callOut + yield + storeInSubgraph(tgt, subgraph) + Edge(srcMethod, tgt, label = child.dispatchType.stripSuffix("_DISPATCH")) + Graph(vertices, edges.distinct, subgraph.toMap) - def storeInSubgraph(method: Method, subgraph: mutable.Map[String, Seq[StoredNode]]): Unit = - method._typeDeclViaAstIn match - case Some(typeDeclName) => - subgraph.put( - typeDeclName.fullName, - subgraph.getOrElse(typeDeclName.fullName, Seq()) ++ Seq(method) - ) - case None => - subgraph.put( - method.astParentFullName, - subgraph.getOrElse(method.astParentFullName, Seq()) ++ Seq(method) - ) + def storeInSubgraph(method: Method, subgraph: mutable.Map[String, Seq[StoredNode]]): Unit = + method._typeDeclViaAstIn match + case Some(typeDeclName) => + subgraph.put( + typeDeclName.fullName, + subgraph.getOrElse(typeDeclName.fullName, Seq()) ++ Seq(method) + ) + case None => + subgraph.put( + method.astParentFullName, + subgraph.getOrElse(method.astParentFullName, Seq()) ++ Seq(method) + ) end CallGraphGenerator diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CdgGenerator.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CdgGenerator.scala index fe581877..297ef337 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CdgGenerator.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CdgGenerator.scala @@ -8,7 +8,7 @@ import scala.jdk.CollectionConverters.* class CdgGenerator extends CfgGenerator: - override val edgeType: String = EdgeTypes.CDG + override val edgeType: String = EdgeTypes.CDG - override def expand(v: StoredNode): Iterator[Edge] = - v._cdgOut.map(node => Edge(v, node, edgeType = edgeType)) + override def expand(v: StoredNode): Iterator[Edge] = + v._cdgOut.map(node => Edge(v, node, edgeType = edgeType)) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CfgGenerator.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CfgGenerator.scala index aaaae82d..35c2420f 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CfgGenerator.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/CfgGenerator.scala @@ -8,55 +8,55 @@ import overflowdb.Node class CfgGenerator: - val edgeType: String = EdgeTypes.CFG - - def generate(methodNode: Method): Graph = - val vertices = methodNode.cfgNode.l ++ List( - methodNode, - methodNode.methodReturn - ) ++ methodNode.parameter.l - val verticesToDisplay = vertices.filter(cfgNodeShouldBeDisplayed) - - def edgesToDisplay(srcNode: StoredNode, visited: List[StoredNode] = List()): List[Edge] = - if visited.contains(srcNode) then - List() - else - val children = expand(srcNode).filter(x => vertices.contains(x.dst)) - val (visible, invisible) = children.partition(x => cfgNodeShouldBeDisplayed(x.dst)) - visible.toList ++ invisible.toList.flatMap { n => - edgesToDisplay(n.dst, visited ++ List(srcNode)).map(y => - Edge(srcNode, y.dst, edgeType = edgeType) - ) - } - - val edges = verticesToDisplay.flatMap { v => - edgesToDisplay(v) - }.distinct - - val allIdsReferencedByEdges = edges.flatMap { edge => - Set(edge.src.id, edge.dst.id) - } - - Graph( - verticesToDisplay - .filter(node => allIdsReferencedByEdges.contains(node.id)), - edges - ) - end generate - - protected def expand(v: StoredNode): Iterator[Edge] = - v._cfgOut.map(node => Edge(v, node, edgeType = edgeType)) - - private def isConditionInControlStructure(v: Node): Boolean = v match - case id: Identifier => id.astParent.isControlStructure - case _ => false - - private def cfgNodeShouldBeDisplayed(v: Node): Boolean = - isConditionInControlStructure(v) || - !(v.isInstanceOf[Literal] || - v.isInstanceOf[Identifier] || - v.isInstanceOf[Block] || - v.isInstanceOf[ControlStructure] || - v.isInstanceOf[JumpTarget] || - v.isInstanceOf[MethodParameterIn]) + val edgeType: String = EdgeTypes.CFG + + def generate(methodNode: Method): Graph = + val vertices = methodNode.cfgNode.l ++ List( + methodNode, + methodNode.methodReturn + ) ++ methodNode.parameter.l + val verticesToDisplay = vertices.filter(cfgNodeShouldBeDisplayed) + + def edgesToDisplay(srcNode: StoredNode, visited: List[StoredNode] = List()): List[Edge] = + if visited.contains(srcNode) then + List() + else + val children = expand(srcNode).filter(x => vertices.contains(x.dst)) + val (visible, invisible) = children.partition(x => cfgNodeShouldBeDisplayed(x.dst)) + visible.toList ++ invisible.toList.flatMap { n => + edgesToDisplay(n.dst, visited ++ List(srcNode)).map(y => + Edge(srcNode, y.dst, edgeType = edgeType) + ) + } + + val edges = verticesToDisplay.flatMap { v => + edgesToDisplay(v) + }.distinct + + val allIdsReferencedByEdges = edges.flatMap { edge => + Set(edge.src.id, edge.dst.id) + } + + Graph( + verticesToDisplay + .filter(node => allIdsReferencedByEdges.contains(node.id)), + edges + ) + end generate + + protected def expand(v: StoredNode): Iterator[Edge] = + v._cfgOut.map(node => Edge(v, node, edgeType = edgeType)) + + private def isConditionInControlStructure(v: Node): Boolean = v match + case id: Identifier => id.astParent.isControlStructure + case _ => false + + private def cfgNodeShouldBeDisplayed(v: Node): Boolean = + isConditionInControlStructure(v) || + !(v.isInstanceOf[Literal] || + v.isInstanceOf[Identifier] || + v.isInstanceOf[Block] || + v.isInstanceOf[ControlStructure] || + v.isInstanceOf[JumpTarget] || + v.isInstanceOf[MethodParameterIn]) end CfgGenerator diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotAstGenerator.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotAstGenerator.scala index 9aa44011..a65eb892 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotAstGenerator.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotAstGenerator.scala @@ -4,9 +4,9 @@ import io.shiftleft.codepropertygraph.generated.nodes.AstNode object DotAstGenerator: - def dotAst[T <: AstNode](traversal: Iterator[T]): Iterator[String] = - traversal.map(dotAst) + def dotAst[T <: AstNode](traversal: Iterator[T]): Iterator[String] = + traversal.map(dotAst) - def dotAst(astRoot: AstNode): String = - val ast = new AstGenerator().generate(astRoot) - DotSerializer.dotGraph(Option(astRoot), ast) + def dotAst(astRoot: AstNode): String = + val ast = new AstGenerator().generate(astRoot) + DotSerializer.dotGraph(Option(astRoot), ast) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotCallGraphGenerator.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotCallGraphGenerator.scala index 33932599..0af6f789 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotCallGraphGenerator.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotCallGraphGenerator.scala @@ -4,6 +4,6 @@ import io.shiftleft.codepropertygraph.Cpg object DotCallGraphGenerator: - def dotCallGraph(cpg: Cpg): Iterator[String] = - val callGraph = new CallGraphGenerator().generate(cpg) - Iterator(DotSerializer.dotGraph(None, callGraph)) + def dotCallGraph(cpg: Cpg): Iterator[String] = + val callGraph = new CallGraphGenerator().generate(cpg) + Iterator(DotSerializer.dotGraph(None, callGraph)) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotCdgGenerator.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotCdgGenerator.scala index 833aedb3..04eeb63a 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotCdgGenerator.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotCdgGenerator.scala @@ -4,9 +4,9 @@ import io.shiftleft.codepropertygraph.generated.nodes.Method object DotCdgGenerator: - def dotCdg(traversal: Iterator[Method]): Iterator[String] = - traversal.map(dotCdg) + def dotCdg(traversal: Iterator[Method]): Iterator[String] = + traversal.map(dotCdg) - def dotCdg(method: Method): String = - val cdg = new CdgGenerator().generate(method) - DotSerializer.dotGraph(Option(method), cdg) + def dotCdg(method: Method): String = + val cdg = new CdgGenerator().generate(method) + DotSerializer.dotGraph(Option(method), cdg) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotCfgGenerator.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotCfgGenerator.scala index e9f15e87..3c48d27e 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotCfgGenerator.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotCfgGenerator.scala @@ -4,9 +4,9 @@ import io.shiftleft.codepropertygraph.generated.nodes.Method object DotCfgGenerator: - def dotCfg(traversal: Iterator[Method]): Iterator[String] = - traversal.map(dotCfg) + def dotCfg(traversal: Iterator[Method]): Iterator[String] = + traversal.map(dotCfg) - def dotCfg(method: Method): String = - val cfg = new CfgGenerator().generate(method) - DotSerializer.dotGraph(Option(method), cfg) + def dotCfg(method: Method): String = + val cfg = new CfgGenerator().generate(method) + DotSerializer.dotGraph(Option(method), cfg) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotSerializer.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotSerializer.scala index a5798064..2a3b7497 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotSerializer.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotSerializer.scala @@ -12,137 +12,137 @@ import scala.language.postfixOps object DotSerializer: - private val charLimit = 50 - - case class Graph( - vertices: List[StoredNode], - edges: List[Edge], - subgraph: Map[String, Seq[StoredNode]] = HashMap.empty[String, Seq[StoredNode]] - ): - - def ++(other: Graph): Graph = - Graph((this.vertices ++ other.vertices).distinct, (this.edges ++ other.edges).distinct) - - case class Edge( - src: StoredNode, - dst: StoredNode, - srcVisible: Boolean = true, - label: String = "", - edgeType: String = "" + private val charLimit = 50 + + case class Graph( + vertices: List[StoredNode], + edges: List[Edge], + subgraph: Map[String, Seq[StoredNode]] = HashMap.empty[String, Seq[StoredNode]] + ): + + def ++(other: Graph): Graph = + Graph((this.vertices ++ other.vertices).distinct, (this.edges ++ other.edges).distinct) + + case class Edge( + src: StoredNode, + dst: StoredNode, + srcVisible: Boolean = true, + label: String = "", + edgeType: String = "" + ) + + def dotGraph( + root: Option[AstNode] = None, + graph: Graph, + withEdgeTypes: Boolean = false + ): String = + val sb = root match + case Some(r) => namedGraphBegin(r) + case None => defaultGraphBegin() + val nodeStrings = graph.vertices.map(nodeToDot) + val edgeStrings = graph.edges.map(e => edgeToDot(e, withEdgeTypes)) + val subgraphStrings = graph.subgraph.zipWithIndex.map { case ((subgraph, nodes), idx) => + nodesToSubGraphs(subgraph, nodes, idx) + } + sb.append((nodeStrings ++ edgeStrings ++ subgraphStrings).mkString("\n")) + graphEnd(sb) + + private def namedGraphBegin(root: AstNode): mutable.StringBuilder = + val sb = new mutable.StringBuilder + val name = escape(root match + case method: Method => method.name + case _ => "" ) - - def dotGraph( - root: Option[AstNode] = None, - graph: Graph, - withEdgeTypes: Boolean = false - ): String = - val sb = root match - case Some(r) => namedGraphBegin(r) - case None => defaultGraphBegin() - val nodeStrings = graph.vertices.map(nodeToDot) - val edgeStrings = graph.edges.map(e => edgeToDot(e, withEdgeTypes)) - val subgraphStrings = graph.subgraph.zipWithIndex.map { case ((subgraph, nodes), idx) => - nodesToSubGraphs(subgraph, nodes, idx) - } - sb.append((nodeStrings ++ edgeStrings ++ subgraphStrings).mkString("\n")) - graphEnd(sb) - - private def namedGraphBegin(root: AstNode): mutable.StringBuilder = - val sb = new mutable.StringBuilder - val name = escape(root match - case method: Method => method.name - case _ => "" - ) - sb.append(s"""digraph "$name" { \n""") - - private def defaultGraphBegin(): mutable.StringBuilder = - val sb = new mutable.StringBuilder - val name = "CPG" - sb.append(s"""digraph "$name" { \n""") - - private def limit(str: String): String = if str.length > charLimit then - s"${str.take(charLimit - 3)}..." + sb.append(s"""digraph "$name" { \n""") + + private def defaultGraphBegin(): mutable.StringBuilder = + val sb = new mutable.StringBuilder + val name = "CPG" + sb.append(s"""digraph "$name" { \n""") + + private def limit(str: String): String = if str.length > charLimit then + s"${str.take(charLimit - 3)}..." + else + str + + private def stringRepr(vertex: StoredNode): String = + val maybeLineNo: Optional[AnyRef] = vertex.propertyOption(PropertyNames.LINE_NUMBER) + escape(vertex match + case call: Call => (call.name, limit(call.code)).toString + case contrl: ControlStructure => + (contrl.label, contrl.controlStructureType, contrl.code).toString + case expr: Expression => + (expr.label, limit(expr.code), limit(toCfgNode(expr).code)).toString + case method: Method => (method.label, method.name).toString + case ret: MethodReturn => (ret.label, ret.typeFullName).toString + case param: MethodParameterIn => ("PARAM", param.code).toString + case local: Local => (local.label, s"${local.code}: ${local.typeFullName}").toString + case target: JumpTarget => (target.label, target.name).toString + case modifier: Modifier => (modifier.label, modifier.modifierType).toString() + case annoAssign: AnnotationParameterAssign => + (annoAssign.label, annoAssign.code).toString() + case annoParam: AnnotationParameter => (annoParam.label, annoParam.code).toString() + case typ: Type => (typ.label, typ.name).toString() + case typeDecl: TypeDecl => (typeDecl.label, typeDecl.name).toString() + case member: Member => (member.label, member.name).toString() + case _ => "" + ) + (if maybeLineNo.isPresent then s"${maybeLineNo.get()}" else "") + end stringRepr + + private def toCfgNode(node: StoredNode): CfgNode = + node match + case node: Identifier => node.parentExpression.get + case node: MethodRef => node.parentExpression.get + case node: Literal => node.parentExpression.get + case node: MethodParameterIn => node.method + case node: MethodParameterOut => node.method.methodReturn + case node: Call if MemberAccess.isGenericMemberAccessName(node.name) => + node.parentExpression.get + case node: CallRepr => node + case node: MethodReturn => node + case node: Expression => node + + private def nodeToDot(node: StoredNode): String = + s""""${node.id}" [label = <${stringRepr(node)}> ]""".stripMargin + + private def edgeToDot(edge: Edge, withEdgeTypes: Boolean): String = + val edgeLabel = if withEdgeTypes then + edge.edgeType + ": " + escape(edge.label) else - str - - private def stringRepr(vertex: StoredNode): String = - val maybeLineNo: Optional[AnyRef] = vertex.propertyOption(PropertyNames.LINE_NUMBER) - escape(vertex match - case call: Call => (call.name, limit(call.code)).toString - case contrl: ControlStructure => - (contrl.label, contrl.controlStructureType, contrl.code).toString - case expr: Expression => - (expr.label, limit(expr.code), limit(toCfgNode(expr).code)).toString - case method: Method => (method.label, method.name).toString - case ret: MethodReturn => (ret.label, ret.typeFullName).toString - case param: MethodParameterIn => ("PARAM", param.code).toString - case local: Local => (local.label, s"${local.code}: ${local.typeFullName}").toString - case target: JumpTarget => (target.label, target.name).toString - case modifier: Modifier => (modifier.label, modifier.modifierType).toString() - case annoAssign: AnnotationParameterAssign => - (annoAssign.label, annoAssign.code).toString() - case annoParam: AnnotationParameter => (annoParam.label, annoParam.code).toString() - case typ: Type => (typ.label, typ.name).toString() - case typeDecl: TypeDecl => (typeDecl.label, typeDecl.name).toString() - case member: Member => (member.label, member.name).toString() - case _ => "" - ) + (if maybeLineNo.isPresent then s"${maybeLineNo.get()}" else "") - end stringRepr - - private def toCfgNode(node: StoredNode): CfgNode = - node match - case node: Identifier => node.parentExpression.get - case node: MethodRef => node.parentExpression.get - case node: Literal => node.parentExpression.get - case node: MethodParameterIn => node.method - case node: MethodParameterOut => node.method.methodReturn - case node: Call if MemberAccess.isGenericMemberAccessName(node.name) => - node.parentExpression.get - case node: CallRepr => node - case node: MethodReturn => node - case node: Expression => node - - private def nodeToDot(node: StoredNode): String = - s""""${node.id}" [label = <${stringRepr(node)}> ]""".stripMargin - - private def edgeToDot(edge: Edge, withEdgeTypes: Boolean): String = - val edgeLabel = if withEdgeTypes then - edge.edgeType + ": " + escape(edge.label) - else - escape(edge.label) - val labelStr = - Some(s""" [ label = "$edgeLabel"] """).filter(_ => edgeLabel != "").getOrElse("") - s""" "${edge.src.id}" -> "${edge.dst.id}" """ + labelStr - - def nodesToSubGraphs(subgraph: String, children: Seq[StoredNode], idx: Int): String = - val escapedName = escape(subgraph) - val childString = children.map { c => s" \"${c.id()}\";" }.mkString("\n") - s""" subgraph cluster_$idx { + escape(edge.label) + val labelStr = + Some(s""" [ label = "$edgeLabel"] """).filter(_ => edgeLabel != "").getOrElse("") + s""" "${edge.src.id}" -> "${edge.dst.id}" """ + labelStr + + def nodesToSubGraphs(subgraph: String, children: Seq[StoredNode], idx: Int): String = + val escapedName = escape(subgraph) + val childString = children.map { c => s" \"${c.id()}\";" }.mkString("\n") + s""" subgraph cluster_$idx { |$childString | label = "$escapedName"; | } |""".stripMargin - /** Escapes common characters that do not conform to HTML character sets. - * @see - * https://www.w3.org/TR/html4/sgml/entities.html - */ - private def escapedChar(ch: Char): String = ch match - case '"' => """ - case '<' => "<" - case '>' => ">" - case '&' => "&" - case _ => - if ch.isControl then "\\0" + Integer.toOctalString(ch.toInt) - else String.valueOf(ch) - - private def escape(str: String): String = - if str == null then - "" - else - str.flatMap(escapedChar) - - private def graphEnd(sb: mutable.StringBuilder): String = - sb.append("\n}\n") - sb.toString + /** Escapes common characters that do not conform to HTML character sets. + * @see + * https://www.w3.org/TR/html4/sgml/entities.html + */ + private def escapedChar(ch: Char): String = ch match + case '"' => """ + case '<' => "<" + case '>' => ">" + case '&' => "&" + case _ => + if ch.isControl then "\\0" + Integer.toOctalString(ch.toInt) + else String.valueOf(ch) + + private def escape(str: String): String = + if str == null then + "" + else + str.flatMap(escapedChar) + + private def graphEnd(sb: mutable.StringBuilder): String = + sb.append("\n}\n") + sb.toString end DotSerializer diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotTypeHierarchyGenerator.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotTypeHierarchyGenerator.scala index aa3bfa9f..f8bc1021 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotTypeHierarchyGenerator.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/DotTypeHierarchyGenerator.scala @@ -4,6 +4,6 @@ import io.shiftleft.codepropertygraph.Cpg object DotTypeHierarchyGenerator: - def dotTypeHierarchy(cpg: Cpg): Iterator[String] = - val typeHierarchy = new TypeHierarchyGenerator().generate(cpg) - Iterator(DotSerializer.dotGraph(None, typeHierarchy)) + def dotTypeHierarchy(cpg: Cpg): Iterator[String] = + val typeHierarchy = new TypeHierarchyGenerator().generate(cpg) + Iterator(DotSerializer.dotGraph(None, typeHierarchy)) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/TypeHierarchyGenerator.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/TypeHierarchyGenerator.scala index bdd1dc1e..0dca69c5 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/TypeHierarchyGenerator.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/dotgenerator/TypeHierarchyGenerator.scala @@ -9,37 +9,37 @@ import scala.collection.mutable class TypeHierarchyGenerator: - def generate(cpg: Cpg): Graph = - val subgraph = mutable.HashMap.empty[String, Seq[StoredNode]] - val vertices = cpg.typeDecl.l - val typeToIsExternal = vertices.map { t => t.fullName -> t.isExternal }.toMap - val edges = - for - srcTypeDecl <- vertices - srcType <- srcTypeDecl._typeViaRefIn.l - _ = storeInSubgraph(srcType, subgraph, typeToIsExternal) - tgtType <- srcTypeDecl.inheritsFromOut - yield - storeInSubgraph(tgtType, subgraph, typeToIsExternal) - Edge(tgtType, srcType) - Graph(vertices.flatMap(_._typeViaRefIn.l), edges.distinct, subgraph.toMap) + def generate(cpg: Cpg): Graph = + val subgraph = mutable.HashMap.empty[String, Seq[StoredNode]] + val vertices = cpg.typeDecl.l + val typeToIsExternal = vertices.map { t => t.fullName -> t.isExternal }.toMap + val edges = + for + srcTypeDecl <- vertices + srcType <- srcTypeDecl._typeViaRefIn.l + _ = storeInSubgraph(srcType, subgraph, typeToIsExternal) + tgtType <- srcTypeDecl.inheritsFromOut + yield + storeInSubgraph(tgtType, subgraph, typeToIsExternal) + Edge(tgtType, srcType) + Graph(vertices.flatMap(_._typeViaRefIn.l), edges.distinct, subgraph.toMap) - def storeInSubgraph( - typ: Type, - subgraph: mutable.Map[String, Seq[StoredNode]], - typeToIsExternal: Map[String, Boolean] - ): Unit = - if !typeToIsExternal(typ.fullName) then - /* + def storeInSubgraph( + typ: Type, + subgraph: mutable.Map[String, Seq[StoredNode]], + typeToIsExternal: Map[String, Boolean] + ): Unit = + if !typeToIsExternal(typ.fullName) then + /* We parse the namespace information instead of looking at the namespace node as types such as inner classes may not be attached to a namespace block - */ - val namespace = - if typ.fullName.contains(".") then - typ.fullName.stripSuffix(s".${typ.name}") - else - typ.fullName.stripSuffix(s"${typ.name}") - subgraph.put(namespace, subgraph.getOrElse(namespace, Seq()) ++ Seq(typ)) - else - subgraph.put("", subgraph.getOrElse("", Seq()) ++ Seq(typ)) + */ + val namespace = + if typ.fullName.contains(".") then + typ.fullName.stripSuffix(s".${typ.name}") + else + typ.fullName.stripSuffix(s"${typ.name}") + subgraph.put(namespace, subgraph.getOrElse(namespace, Seq()) ++ Seq(typ)) + else + subgraph.put("", subgraph.getOrElse("", Seq()) ++ Seq(typ)) end TypeHierarchyGenerator diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/AccessPathHandling.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/AccessPathHandling.scala index 743bd805..da60876c 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/AccessPathHandling.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/AccessPathHandling.scala @@ -9,134 +9,134 @@ import scala.jdk.CollectionConverters.IteratorHasAsScala object AccessPathHandling: - def leafToTrackedBaseAndAccessPathInternal(node: StoredNode) - : Option[(TrackedBase, List[AccessElement])] = - node match - case node: MethodParameterIn => Some((TrackedNamedVariable(node.name), Nil)) - case node: MethodParameterOut => Some((TrackedNamedVariable(node.name), Nil)) - case node: Identifier => Some((TrackedNamedVariable(node.name), Nil)) - case node: Literal => Some((TrackedLiteral(node), Nil)) - case node: MethodRef => Some((TrackedMethod(node), Nil)) - case node: TypeRef => Some((TrackedTypeRef(node), Nil)) - case _: Return => Some((TrackedFormalReturn, Nil)) - case _: MethodReturn => Some((TrackedFormalReturn, Nil)) - case _: Unknown => Some((TrackedUnknown, Nil)) - case _: ControlStructure => Some((TrackedUnknown, Nil)) - // FieldIdentifiers are only fake arguments, hence should not be tracked - case _: FieldIdentifier => Some((TrackedUnknown, Nil)) - case _ => None + def leafToTrackedBaseAndAccessPathInternal(node: StoredNode) + : Option[(TrackedBase, List[AccessElement])] = + node match + case node: MethodParameterIn => Some((TrackedNamedVariable(node.name), Nil)) + case node: MethodParameterOut => Some((TrackedNamedVariable(node.name), Nil)) + case node: Identifier => Some((TrackedNamedVariable(node.name), Nil)) + case node: Literal => Some((TrackedLiteral(node), Nil)) + case node: MethodRef => Some((TrackedMethod(node), Nil)) + case node: TypeRef => Some((TrackedTypeRef(node), Nil)) + case _: Return => Some((TrackedFormalReturn, Nil)) + case _: MethodReturn => Some((TrackedFormalReturn, Nil)) + case _: Unknown => Some((TrackedUnknown, Nil)) + case _: ControlStructure => Some((TrackedUnknown, Nil)) + // FieldIdentifiers are only fake arguments, hence should not be tracked + case _: FieldIdentifier => Some((TrackedUnknown, Nil)) + case _ => None - private val logger = LoggerFactory.getLogger(getClass) - private var hasWarnedDeprecations = false + private val logger = LoggerFactory.getLogger(getClass) + private var hasWarnedDeprecations = false - def memberAccessToPath(memberAccess: Call, tail: List[AccessElement]): List[AccessElement] = - memberAccess.name match - case Operators.memberAccess | Operators.indirectMemberAccess => - if !hasWarnedDeprecations then - logger.debug(s"Deprecated Operator ${memberAccess.name} on $memberAccess") - hasWarnedDeprecations = true - memberAccess - .argumentOption(2) - .collect { - case node: Literal => ConstantAccess(node.code) - case node: Identifier => ConstantAccess(node.name) - case other if other.propertyOption(PropertyNames.NAME).isPresent => - logger.warn( - s"unexpected/deprecated node encountered: $other with properties: ${other.propertiesMap()}" - ) - ConstantAccess(other.property(Properties.NAME)) - } - .getOrElse(VariableAccess) :: tail + def memberAccessToPath(memberAccess: Call, tail: List[AccessElement]): List[AccessElement] = + memberAccess.name match + case Operators.memberAccess | Operators.indirectMemberAccess => + if !hasWarnedDeprecations then + logger.debug(s"Deprecated Operator ${memberAccess.name} on $memberAccess") + hasWarnedDeprecations = true + memberAccess + .argumentOption(2) + .collect { + case node: Literal => ConstantAccess(node.code) + case node: Identifier => ConstantAccess(node.name) + case other if other.propertyOption(PropertyNames.NAME).isPresent => + logger.warn( + s"unexpected/deprecated node encountered: $other with properties: ${other.propertiesMap()}" + ) + ConstantAccess(other.property(Properties.NAME)) + } + .getOrElse(VariableAccess) :: tail - case Operators.computedMemberAccess | Operators.indirectComputedMemberAccess => - if !hasWarnedDeprecations then - logger.debug(s"Deprecated Operator ${memberAccess.name} on $memberAccess") - hasWarnedDeprecations = true - memberAccess - .argumentOption(2) - .collect { case lit: Literal => - ConstantAccess(lit.code) - } - .getOrElse(VariableAccess) :: tail - case Operators.indirection => - IndirectionAccess :: tail - case Operators.addressOf => - AddressOf :: tail - case Operators.fieldAccess => - extractAccessStringTokenForFieldAccess(memberAccess) :: tail - case Operators.indexAccess => - extractAccessStringToken(memberAccess) :: tail - case Operators.indirectFieldAccess => - // we will reverse the list in the end - extractAccessStringTokenForFieldAccess(memberAccess) :: IndirectionAccess :: tail - case Operators.indirectIndexAccess => - // we will reverse the list in the end - IndirectionAccess :: extractAccessIntToken(memberAccess) :: tail - case Operators.pointerShift => - extractAccessIntToken(memberAccess) :: tail - case Operators.getElementPtr => - // we will reverse the list in the end - AddressOf :: extractAccessStringTokenForFieldAccess( - memberAccess - ) :: IndirectionAccess :: tail + case Operators.computedMemberAccess | Operators.indirectComputedMemberAccess => + if !hasWarnedDeprecations then + logger.debug(s"Deprecated Operator ${memberAccess.name} on $memberAccess") + hasWarnedDeprecations = true + memberAccess + .argumentOption(2) + .collect { case lit: Literal => + ConstantAccess(lit.code) + } + .getOrElse(VariableAccess) :: tail + case Operators.indirection => + IndirectionAccess :: tail + case Operators.addressOf => + AddressOf :: tail + case Operators.fieldAccess => + extractAccessStringTokenForFieldAccess(memberAccess) :: tail + case Operators.indexAccess => + extractAccessStringToken(memberAccess) :: tail + case Operators.indirectFieldAccess => + // we will reverse the list in the end + extractAccessStringTokenForFieldAccess(memberAccess) :: IndirectionAccess :: tail + case Operators.indirectIndexAccess => + // we will reverse the list in the end + IndirectionAccess :: extractAccessIntToken(memberAccess) :: tail + case Operators.pointerShift => + extractAccessIntToken(memberAccess) :: tail + case Operators.getElementPtr => + // we will reverse the list in the end + AddressOf :: extractAccessStringTokenForFieldAccess( + memberAccess + ) :: IndirectionAccess :: tail - private def extractAccessStringTokenForFieldAccess(memberAccess: Call): AccessElement = - memberAccess.argumentOption(2) match - case None => - logger.warn( - s"Invalid AST: Found member access without second argument." + - s" Member access CODE: ${memberAccess.code}" + - s" In method ${memberAccess.method.fullName}" - ) - VariableAccess - case Some(literal: Literal) => ConstantAccess(literal.code) - case Some(fieldIdentifier: FieldIdentifier) => - ConstantAccess(fieldIdentifier.canonicalName) - case Some(identifier: Identifier) => - // TODO remove this case. - // This is handling a very old CPG format version where IDENTIFIER was used instead of FIELD_IDENTIFIER. - // Sadly we need this for now to support a GO cpg. - ConstantAccess(identifier.name) - case _ => VariableAccess + private def extractAccessStringTokenForFieldAccess(memberAccess: Call): AccessElement = + memberAccess.argumentOption(2) match + case None => + logger.warn( + s"Invalid AST: Found member access without second argument." + + s" Member access CODE: ${memberAccess.code}" + + s" In method ${memberAccess.method.fullName}" + ) + VariableAccess + case Some(literal: Literal) => ConstantAccess(literal.code) + case Some(fieldIdentifier: FieldIdentifier) => + ConstantAccess(fieldIdentifier.canonicalName) + case Some(identifier: Identifier) => + // TODO remove this case. + // This is handling a very old CPG format version where IDENTIFIER was used instead of FIELD_IDENTIFIER. + // Sadly we need this for now to support a GO cpg. + ConstantAccess(identifier.name) + case _ => VariableAccess - private def extractAccessStringToken(memberAccess: Call): AccessElement = - memberAccess.argumentOption(2) match - case None => - logger.warn( - s"Invalid AST: Found member access without second argument." + - s" Member access CODE: ${memberAccess.code}" + - s" In method ${memberAccess.method.fullName}" - ) - VariableAccess - case Some(literal: Literal) => ConstantAccess(literal.code) - case Some(fieldIdentifier: FieldIdentifier) => - ConstantAccess(fieldIdentifier.canonicalName) - case _ => VariableAccess + private def extractAccessStringToken(memberAccess: Call): AccessElement = + memberAccess.argumentOption(2) match + case None => + logger.warn( + s"Invalid AST: Found member access without second argument." + + s" Member access CODE: ${memberAccess.code}" + + s" In method ${memberAccess.method.fullName}" + ) + VariableAccess + case Some(literal: Literal) => ConstantAccess(literal.code) + case Some(fieldIdentifier: FieldIdentifier) => + ConstantAccess(fieldIdentifier.canonicalName) + case _ => VariableAccess - private def extractAccessIntToken(memberAccess: Call): AccessElement = - memberAccess.argumentOption(2) match - case None => - logger.warn( - s"Invalid AST: Found member access without second argument." + - s" Member access CODE: ${memberAccess.code}" + - s" In method ${memberAccess.method.fullName}" - ) - VariablePointerShift - case Some(literal: Literal) => - literal.code.toIntOption.map(PointerShift.apply).getOrElse(VariablePointerShift) - case Some(fieldIdentifier: FieldIdentifier) => - fieldIdentifier.canonicalName.toIntOption - .map(PointerShift.apply) - .getOrElse(VariablePointerShift) - case _ => VariablePointerShift + private def extractAccessIntToken(memberAccess: Call): AccessElement = + memberAccess.argumentOption(2) match + case None => + logger.warn( + s"Invalid AST: Found member access without second argument." + + s" Member access CODE: ${memberAccess.code}" + + s" In method ${memberAccess.method.fullName}" + ) + VariablePointerShift + case Some(literal: Literal) => + literal.code.toIntOption.map(PointerShift.apply).getOrElse(VariablePointerShift) + case Some(fieldIdentifier: FieldIdentifier) => + fieldIdentifier.canonicalName.toIntOption + .map(PointerShift.apply) + .getOrElse(VariablePointerShift) + case _ => VariablePointerShift - def lastExpressionInBlock(block: Block): Option[Expression] = - block._astOut - .collect { - case node: Expression if !node.isInstanceOf[Local] && !node.isInstanceOf[Method] => - node - } - .toVector - .sortBy(_.order) - .lastOption + def lastExpressionInBlock(block: Block): Option[Expression] = + block._astOut + .collect { + case node: Expression if !node.isInstanceOf[Local] && !node.isInstanceOf[Method] => + node + } + .toVector + .sortBy(_.order) + .lastOption end AccessPathHandling diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/HasLocation.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/HasLocation.scala index ba31a3df..43927e76 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/HasLocation.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/HasLocation.scala @@ -3,4 +3,4 @@ package io.shiftleft.semanticcpg.language import io.shiftleft.codepropertygraph.generated.nodes.NewLocation trait HasLocation extends Any: - def location: NewLocation + def location: NewLocation diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/ICallResolver.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/ICallResolver.scala index 83e0d12e..f03d46a6 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/ICallResolver.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/ICallResolver.scala @@ -7,75 +7,75 @@ import scala.jdk.CollectionConverters.* trait ICallResolver: - def getUnresolvedMethodFullNames(callsite: CallRepr): Iterable[String] = - triggerCallsiteResolution(callsite) - getUnresolvedMethodFullNamesInternal(callsite) - - def getUnresolvedMethodFullNamesInternal(callsite: CallRepr): Iterable[String] - - /** Get methods called at the given callsite. This internally calls triggerCallsiteResolution. - */ - def getCalledMethods(callsite: CallRepr): Iterable[Method] = - val combined = mutable.ArrayBuffer.empty[Method] - if callsite.nonEmpty then - triggerCallsiteResolution(callsite) - callsite._callOut.foreach(method => combined.append(method.asInstanceOf[Method])) - combined.appendAll(getResolvedCalledMethods(callsite)) - combined - - /** Same as getCalledMethods but with traversal return type. - */ - def getCalledMethodsAsTraversal(callsite: CallRepr): Iterator[Method] = - getCalledMethods(callsite).iterator - - /** Get callsites of the given method. This internally calls triggerMethodResolution. - */ - def getMethodCallsites(method: Method): Iterable[CallRepr] = - triggerMethodCallsiteResolution(method) - // The same call sites of a method can be found via static and dynamic lookup. - // This is for example the case for Java virtual call sites which are statically assert - // a certain method which could be overriden. If we are looking for the call sites of - // such a statically asserted method, we find it twice and thus deduplicate here. - val combined = mutable.LinkedHashSet.empty[CallRepr] - method._callIn.foreach(call => combined.add(call.asInstanceOf[CallRepr])) - combined.addAll(getResolvedMethodCallsites(method)) - - combined.toBuffer - - /** Same as getMethodCallsites but with traversal return type. - */ - def getMethodCallsitesAsTraversal(method: Method): Iterator[CallRepr] = - getMethodCallsites(method).iterator - - /** Starts data flow tracking to find all method which could be called at the given callsite. - * The result is stored in the resolver internal cache. - */ - def triggerCallsiteResolution(callsite: CallRepr): Unit - - /** Starts data flow tracking to find all callsites which could call the given method. The - * result is stored in the resolver internal cache. - */ - def triggerMethodCallsiteResolution(method: Method): Unit - - /** Retrieve results of triggerCallsiteResolution. - */ - def getResolvedCalledMethods(callsite: CallRepr): Iterable[Method] - - /** Retrieve results of triggerMethodResolution. - */ - def getResolvedMethodCallsites(method: Method): Iterable[CallRepr] + def getUnresolvedMethodFullNames(callsite: CallRepr): Iterable[String] = + triggerCallsiteResolution(callsite) + getUnresolvedMethodFullNamesInternal(callsite) + + def getUnresolvedMethodFullNamesInternal(callsite: CallRepr): Iterable[String] + + /** Get methods called at the given callsite. This internally calls triggerCallsiteResolution. + */ + def getCalledMethods(callsite: CallRepr): Iterable[Method] = + val combined = mutable.ArrayBuffer.empty[Method] + if callsite.nonEmpty then + triggerCallsiteResolution(callsite) + callsite._callOut.foreach(method => combined.append(method.asInstanceOf[Method])) + combined.appendAll(getResolvedCalledMethods(callsite)) + combined + + /** Same as getCalledMethods but with traversal return type. + */ + def getCalledMethodsAsTraversal(callsite: CallRepr): Iterator[Method] = + getCalledMethods(callsite).iterator + + /** Get callsites of the given method. This internally calls triggerMethodResolution. + */ + def getMethodCallsites(method: Method): Iterable[CallRepr] = + triggerMethodCallsiteResolution(method) + // The same call sites of a method can be found via static and dynamic lookup. + // This is for example the case for Java virtual call sites which are statically assert + // a certain method which could be overriden. If we are looking for the call sites of + // such a statically asserted method, we find it twice and thus deduplicate here. + val combined = mutable.LinkedHashSet.empty[CallRepr] + method._callIn.foreach(call => combined.add(call.asInstanceOf[CallRepr])) + combined.addAll(getResolvedMethodCallsites(method)) + + combined.toBuffer + + /** Same as getMethodCallsites but with traversal return type. + */ + def getMethodCallsitesAsTraversal(method: Method): Iterator[CallRepr] = + getMethodCallsites(method).iterator + + /** Starts data flow tracking to find all method which could be called at the given callsite. The + * result is stored in the resolver internal cache. + */ + def triggerCallsiteResolution(callsite: CallRepr): Unit + + /** Starts data flow tracking to find all callsites which could call the given method. The result + * is stored in the resolver internal cache. + */ + def triggerMethodCallsiteResolution(method: Method): Unit + + /** Retrieve results of triggerCallsiteResolution. + */ + def getResolvedCalledMethods(callsite: CallRepr): Iterable[Method] + + /** Retrieve results of triggerMethodResolution. + */ + def getResolvedMethodCallsites(method: Method): Iterable[CallRepr] end ICallResolver object NoResolve extends ICallResolver: - def triggerCallsiteResolution(callsite: CallRepr): Unit = {} + def triggerCallsiteResolution(callsite: CallRepr): Unit = {} - def triggerMethodCallsiteResolution(method: Method): Unit = {} + def triggerMethodCallsiteResolution(method: Method): Unit = {} - override def getResolvedCalledMethods(callsite: CallRepr): Iterable[Method] = - Iterable.empty + override def getResolvedCalledMethods(callsite: CallRepr): Iterable[Method] = + Iterable.empty - override def getResolvedMethodCallsites(method: Method): Iterable[CallRepr] = - Iterable.empty + override def getResolvedMethodCallsites(method: Method): Iterable[CallRepr] = + Iterable.empty - override def getUnresolvedMethodFullNamesInternal(callsite: CallRepr): Iterable[String] = - Iterable.empty + override def getUnresolvedMethodFullNamesInternal(callsite: CallRepr): Iterable[String] = + Iterable.empty diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/LocationCreator.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/LocationCreator.scala index 1e314f9a..bdecb4b3 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/LocationCreator.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/LocationCreator.scala @@ -10,68 +10,68 @@ import scala.annotation.tailrec * all (and only) steps extending DataFlowObject should/must have `newSink`, `newSource` and `newLocation` */ object LocationCreator: - private val logger: Logger = LoggerFactory.getLogger(getClass) + private val logger: Logger = LoggerFactory.getLogger(getClass) - def apply(node: StoredNode)(implicit finder: NodeExtensionFinder): NewLocation = - try - location(node) - catch - case exc @ (_: NoSuchElementException | _: ClassCastException) => - logger.debug(s"Cannot determine location for ${node.label} due to broken CPG", exc) - emptyLocation(node.label, Some(node)) + def apply(node: StoredNode)(implicit finder: NodeExtensionFinder): NewLocation = + try + location(node) + catch + case exc @ (_: NoSuchElementException | _: ClassCastException) => + logger.debug(s"Cannot determine location for ${node.label} due to broken CPG", exc) + emptyLocation(node.label, Some(node)) - private def location(node: StoredNode)(implicit finder: NodeExtensionFinder): NewLocation = - finder(node) match - case Some(n: HasLocation) => n.location - case _ => LocationCreator.emptyLocation("", None) + private def location(node: StoredNode)(implicit finder: NodeExtensionFinder): NewLocation = + finder(node) match + case Some(n: HasLocation) => n.location + case _ => LocationCreator.emptyLocation("", None) - def apply( - node: StoredNode, - symbol: String, - label: String, - lineNumber: Option[Integer], - method: Method - ): NewLocation = - if method == null then - NewLocation().node(node) - else - val typeOption = methodToTypeDecl(method) - val typeName = typeOption.map(_.fullName).getOrElse("") - val typeShortName = typeOption.map(_.name).getOrElse("") + def apply( + node: StoredNode, + symbol: String, + label: String, + lineNumber: Option[Integer], + method: Method + ): NewLocation = + if method == null then + NewLocation().node(node) + else + val typeOption = methodToTypeDecl(method) + val typeName = typeOption.map(_.fullName).getOrElse("") + val typeShortName = typeOption.map(_.name).getOrElse("") - val namespaceOption = - for - tpe <- typeOption - namespaceBlock <- tpe.namespaceBlock - namespace <- namespaceBlock._namespaceViaRefOut.nextOption() - yield namespace.name - val namespaceName = namespaceOption.getOrElse("") + val namespaceOption = + for + tpe <- typeOption + namespaceBlock <- tpe.namespaceBlock + namespace <- namespaceBlock._namespaceViaRefOut.nextOption() + yield namespace.name + val namespaceName = namespaceOption.getOrElse("") - NewLocation() - .symbol(symbol) - .methodFullName(method.fullName) - .methodShortName(method.name) - .packageName(namespaceName) - .lineNumber(lineNumber) - .className(typeName) - .classShortName(typeShortName) - .nodeLabel(label) - .filename(if method.filename.isEmpty then "N/A" else method.filename) - .node(node) + NewLocation() + .symbol(symbol) + .methodFullName(method.fullName) + .methodShortName(method.name) + .packageName(namespaceName) + .lineNumber(lineNumber) + .className(typeName) + .classShortName(typeShortName) + .nodeLabel(label) + .filename(if method.filename.isEmpty then "N/A" else method.filename) + .node(node) - private def methodToTypeDecl(method: Method): Option[TypeDecl] = - findVertex(method, _.isInstanceOf[TypeDecl]).map(_.asInstanceOf[TypeDecl]) + private def methodToTypeDecl(method: Method): Option[TypeDecl] = + findVertex(method, _.isInstanceOf[TypeDecl]).map(_.asInstanceOf[TypeDecl]) - @tailrec - private def findVertex( - node: StoredNode, - instanceCheck: StoredNode => Boolean - ): Option[StoredNode] = - node._astIn.nextOption() match - case Some(head) if instanceCheck(head) => Some(head) - case Some(head) => findVertex(head, instanceCheck) - case None => None + @tailrec + private def findVertex( + node: StoredNode, + instanceCheck: StoredNode => Boolean + ): Option[StoredNode] = + node._astIn.nextOption() match + case Some(head) if instanceCheck(head) => Some(head) + case Some(head) => findVertex(head, instanceCheck) + case None => None - def emptyLocation(label: String, node: Option[StoredNode]): NewLocation = - NewLocation().nodeLabel(label).node(node) + def emptyLocation(label: String, node: Option[StoredNode]): NewLocation = + NewLocation().nodeLabel(label).node(node) end LocationCreator diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NewNodeSteps.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NewNodeSteps.scala index e83cc720..37d3e4d4 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NewNodeSteps.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NewNodeSteps.scala @@ -4,14 +4,14 @@ import io.shiftleft.codepropertygraph.generated.nodes.NewNode import overflowdb.BatchedUpdate.DiffGraphBuilder trait HasStoreMethod: - def store()(implicit diffBuilder: DiffGraphBuilder): Unit + def store()(implicit diffBuilder: DiffGraphBuilder): Unit class NewNodeSteps[A <: NewNode](val traversal: Iterator[A]) extends HasStoreMethod: - override def store()(implicit diffBuilder: DiffGraphBuilder): Unit = - traversal.sideEffect(storeRecursively).iterate() + override def store()(implicit diffBuilder: DiffGraphBuilder): Unit = + traversal.sideEffect(storeRecursively).iterate() - private def storeRecursively(newNode: NewNode)(implicit diffBuilder: DiffGraphBuilder): Unit = - diffBuilder.addNode(newNode) + private def storeRecursively(newNode: NewNode)(implicit diffBuilder: DiffGraphBuilder): Unit = + diffBuilder.addNode(newNode) - def label: Iterator[String] = traversal.map(_.label) + def label: Iterator[String] = traversal.map(_.label) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NewTagNodePairTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NewTagNodePairTraversal.scala index be2a258f..ca1c60e6 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NewTagNodePairTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NewTagNodePairTraversal.scala @@ -5,14 +5,14 @@ import io.shiftleft.codepropertygraph.generated.nodes.{NewNode, NewTagNodePair, import overflowdb.BatchedUpdate.DiffGraphBuilder class NewTagNodePairTraversal(traversal: Iterator[NewTagNodePair]) extends HasStoreMethod: - override def store()(implicit diffGraph: DiffGraphBuilder): Unit = - traversal.foreach { tagNodePair => - val tag = tagNodePair.tag - val tagValue = tagNodePair.node - diffGraph.addNode(tag.asInstanceOf[NewNode]) - tagValue match - case tagValue: StoredNode => - diffGraph.addEdge(tagValue, tag.asInstanceOf[NewNode], EdgeTypes.TAGGED_BY) - case tagValue: NewNode => - diffGraph.addEdge(tagValue, tag.asInstanceOf[NewNode], EdgeTypes.TAGGED_BY, Nil) - } + override def store()(implicit diffGraph: DiffGraphBuilder): Unit = + traversal.foreach { tagNodePair => + val tag = tagNodePair.tag + val tagValue = tagNodePair.node + diffGraph.addNode(tag.asInstanceOf[NewNode]) + tagValue match + case tagValue: StoredNode => + diffGraph.addEdge(tagValue, tag.asInstanceOf[NewNode], EdgeTypes.TAGGED_BY) + case tagValue: NewNode => + diffGraph.addEdge(tagValue, tag.asInstanceOf[NewNode], EdgeTypes.TAGGED_BY, Nil) + } diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeExtensionFinder.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeExtensionFinder.scala index 299a0236..2c8dabc5 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeExtensionFinder.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeExtensionFinder.scala @@ -26,18 +26,18 @@ import io.shiftleft.semanticcpg.language.nodemethods.{ } trait NodeExtensionFinder: - def apply(n: StoredNode): Option[NodeExtension] + def apply(n: StoredNode): Option[NodeExtension] object DefaultNodeExtensionFinder extends NodeExtensionFinder: - override def apply(node: StoredNode): Option[NodeExtension] = - node match - case n: Method => Some(new MethodMethods(n)) - case n: MethodParameterIn => Some(new MethodParameterInMethods(n)) - case n: MethodParameterOut => Some(new MethodParameterOutMethods(n)) - case n: MethodReturn => Some(new MethodReturnMethods(n)) - case n: Call => Some(new CallMethods(n)) - case n: Identifier => Some(new IdentifierMethods(n)) - case n: Literal => Some(new LiteralMethods(n)) - case n: Local => Some(new LocalMethods(n)) - case n: MethodRef => Some(new MethodRefMethods(n)) - case _ => None + override def apply(node: StoredNode): Option[NodeExtension] = + node match + case n: Method => Some(new MethodMethods(n)) + case n: MethodParameterIn => Some(new MethodParameterInMethods(n)) + case n: MethodParameterOut => Some(new MethodParameterOutMethods(n)) + case n: MethodReturn => Some(new MethodReturnMethods(n)) + case n: Call => Some(new CallMethods(n)) + case n: Identifier => Some(new IdentifierMethods(n)) + case n: Literal => Some(new LiteralMethods(n)) + case n: Local => Some(new LocalMethods(n)) + case n: MethodRef => Some(new MethodRefMethods(n)) + case _ => None diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeOrdering.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeOrdering.scala index f0e13235..66f1bca2 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeOrdering.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeOrdering.scala @@ -4,46 +4,46 @@ import scala.collection.mutable object NodeOrdering: - /** For a given CFG with the entry node `cfgEntry` and an expansion function `expand`, return a - * map that associates each node with an index such that nodes are numbered in post order. - */ - def postOrderNumbering[NodeType]( - cfgEntry: NodeType, - expand: NodeType => Iterator[NodeType] - ): mutable.LinkedHashMap[NodeType, Int] = - var stack = (cfgEntry, expand(cfgEntry)) :: Nil - val visited = mutable.Set.empty[NodeType] - val numbering = mutable.LinkedHashMap.empty[NodeType, Int] - var nextNumber = 0 + /** For a given CFG with the entry node `cfgEntry` and an expansion function `expand`, return a + * map that associates each node with an index such that nodes are numbered in post order. + */ + def postOrderNumbering[NodeType]( + cfgEntry: NodeType, + expand: NodeType => Iterator[NodeType] + ): mutable.LinkedHashMap[NodeType, Int] = + var stack = (cfgEntry, expand(cfgEntry)) :: Nil + val visited = mutable.Set.empty[NodeType] + val numbering = mutable.LinkedHashMap.empty[NodeType, Int] + var nextNumber = 0 - while stack.nonEmpty do - val (node, successors) = stack.head - visited += node + while stack.nonEmpty do + val (node, successors) = stack.head + visited += node - if successors.hasNext then - val successor = successors.next() - if !visited.contains(successor) then - stack = (successor, expand(successor)) :: stack - else - stack = stack.tail - numbering.put(node, nextNumber) - nextNumber += 1 - numbering - end postOrderNumbering + if successors.hasNext then + val successor = successors.next() + if !visited.contains(successor) then + stack = (successor, expand(successor)) :: stack + else + stack = stack.tail + numbering.put(node, nextNumber) + nextNumber += 1 + numbering + end postOrderNumbering - /** For a list of (node, number) pairs, return the list of nodes obtained by sorting nodes - * according to number in reverse order. - */ - def reverseNodeList[NodeType](nodeNumberPairs: List[(NodeType, Int)]): List[NodeType] = - nodeNumberPairs - .sortBy { case (_, num) => -num } - .map { case (node, _) => node } + /** For a list of (node, number) pairs, return the list of nodes obtained by sorting nodes + * according to number in reverse order. + */ + def reverseNodeList[NodeType](nodeNumberPairs: List[(NodeType, Int)]): List[NodeType] = + nodeNumberPairs + .sortBy { case (_, num) => -num } + .map { case (node, _) => node } - /** For a list of (node, number) pairs, return the list of nodes obtained by sorting nodes - * according to number. - */ - def nodeList[NodeType](nodeNumberPairs: List[(NodeType, Int)]): List[NodeType] = - nodeNumberPairs - .sortBy { case (_, num) => num } - .map { case (node, _) => node } + /** For a list of (node, number) pairs, return the list of nodes obtained by sorting nodes + * according to number. + */ + def nodeList[NodeType](nodeNumberPairs: List[(NodeType, Int)]): List[NodeType] = + nodeNumberPairs + .sortBy { case (_, num) => num } + .map { case (node, _) => node } end NodeOrdering diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeSteps.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeSteps.scala index 7aac3442..a9dea552 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeSteps.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeSteps.scala @@ -15,29 +15,29 @@ import overflowdb.traversal.help.Doc @help.Traversal(elementType = classOf[StoredNode]) class NodeSteps[NodeType <: StoredNode](val traversal: Iterator[NodeType]) extends AnyVal: - @Doc( - info = "The source file this code is in", - longInfo = """ + @Doc( + info = "The source file this code is in", + longInfo = """ |Not all but most node in the graph can be associated with |a specific source file they appear in. `file` provides |the file node that represents that source file. |""" - ) - def file: Iterator[File] = - traversal - .choose(_.label) { - case NodeTypes.NAMESPACE => _.in(EdgeTypes.REF).out(EdgeTypes.SOURCE_FILE) - case NodeTypes.COMMENT => _.in(EdgeTypes.AST).hasLabel(NodeTypes.FILE) - case _ => - _.repeat(_.coalesce(_.out(EdgeTypes.SOURCE_FILE), _.in(EdgeTypes.AST)))(_.until( - _.hasLabel(NodeTypes.FILE) - )) - } - .cast[File] + ) + def file: Iterator[File] = + traversal + .choose(_.label) { + case NodeTypes.NAMESPACE => _.in(EdgeTypes.REF).out(EdgeTypes.SOURCE_FILE) + case NodeTypes.COMMENT => _.in(EdgeTypes.AST).hasLabel(NodeTypes.FILE) + case _ => + _.repeat(_.coalesce(_.out(EdgeTypes.SOURCE_FILE), _.in(EdgeTypes.AST)))(_.until( + _.hasLabel(NodeTypes.FILE) + )) + } + .cast[File] - @Doc( - info = "Location, including filename and line number", - longInfo = """ + @Doc( + info = "Location, including filename and line number", + longInfo = """ |Most nodes of the graph can be associated with a specific |location in code, and `location` provides this location. |The return value is an object providing, e.g., filename, @@ -46,53 +46,53 @@ class NodeSteps[NodeType <: StoredNode](val traversal: Iterator[NodeType]) exten |to the line number alone, without requiring any parsing |on the user's side. |""" - ) - def location(implicit finder: NodeExtensionFinder): Iterator[NewLocation] = - traversal.map(_.location) + ) + def location(implicit finder: NodeExtensionFinder): Iterator[NewLocation] = + traversal.map(_.location) - @Doc( - info = "Display code (with syntax highlighting)", - longInfo = """ + @Doc( + info = "Display code (with syntax highlighting)", + longInfo = """ |For methods, dump the method code. For expressions, |dump the method code along with an arrow pointing |to the expression. Uses ansi-color highlighting. |This only works for source frontends. |""" - ) - def dump(implicit finder: NodeExtensionFinder): List[String] = - _dump(highlight = true) + ) + def dump(implicit finder: NodeExtensionFinder): List[String] = + _dump(highlight = true) - @Doc( - info = "Display code (without syntax highlighting)", - longInfo = """ + @Doc( + info = "Display code (without syntax highlighting)", + longInfo = """ |For methods, dump the method code. For expressions, |dump the method code along with an arrow pointing |to the expression. No color highlighting. |""" - ) - def dumpRaw(implicit finder: NodeExtensionFinder): List[String] = - _dump(highlight = false) + ) + def dumpRaw(implicit finder: NodeExtensionFinder): List[String] = + _dump(highlight = false) - private def _dump(highlight: Boolean)(implicit finder: NodeExtensionFinder): List[String] = - // initialized on first element as we need the graph for retrieving the metaData node. - // TODO: there should be a step to retrieve the metaData node for any node - // so we could avoid instantiating a new Cpg everytime using dump - var cpg: Cpg = null - traversal.map { node => - if cpg == null then cpg = new Cpg(node.graph) - val language = cpg.metaData.language.headOption - val rootPath = cpg.metaData.root.headOption - CodeDumper.dump(node.location, language, rootPath, highlight) - }.l + private def _dump(highlight: Boolean)(implicit finder: NodeExtensionFinder): List[String] = + // initialized on first element as we need the graph for retrieving the metaData node. + // TODO: there should be a step to retrieve the metaData node for any node + // so we could avoid instantiating a new Cpg everytime using dump + var cpg: Cpg = null + traversal.map { node => + if cpg == null then cpg = new Cpg(node.graph) + val language = cpg.metaData.language.headOption + val rootPath = cpg.metaData.root.headOption + CodeDumper.dump(node.location, language, rootPath, highlight) + }.l - /* follow the incoming edges of the given type as long as possible */ - protected def walkIn(edgeType: String): Iterator[Node] = - traversal - .repeat(_.in(edgeType))(_.until(_.in(edgeType).countTrav.filter(_ == 0))) + /* follow the incoming edges of the given type as long as possible */ + protected def walkIn(edgeType: String): Iterator[Node] = + traversal + .repeat(_.in(edgeType))(_.until(_.in(edgeType).countTrav.filter(_ == 0))) - @Doc( - info = "Tag node with `tagName`", - longInfo = """ + @Doc( + info = "Tag node with `tagName`", + longInfo = """ |This method can be used to tag nodes in the graph such that |they can later be looked up easily via `cpg.tag`. Tags are |key value pairs, and they can be created with `newTagNodePair`. @@ -100,31 +100,31 @@ class NodeSteps[NodeType <: StoredNode](val traversal: Iterator[NodeType]) exten |utility method `newTagNode(key)`, which is equivalent to |`newTagNode(key, "")`. |""", - example = """.newTagNode("foo")""" - ) - def newTagNode(tagName: String): NewTagNodePairTraversal = newTagNodePair(tagName, "") + example = """.newTagNode("foo")""" + ) + def newTagNode(tagName: String): NewTagNodePairTraversal = newTagNodePair(tagName, "") - @Doc( - info = "Tag node with (`tagName`, `tagValue`)", - longInfo = "", - example = """.newTagNodePair("key","val")""" - ) - def newTagNodePair(tagName: String, tagValue: String): NewTagNodePairTraversal = - new NewTagNodePairTraversal(traversal.map { node => - NewTagNodePair() - .tag(NewTag().name(tagName).value(tagValue)) - .node(node) - }) + @Doc( + info = "Tag node with (`tagName`, `tagValue`)", + longInfo = "", + example = """.newTagNodePair("key","val")""" + ) + def newTagNodePair(tagName: String, tagValue: String): NewTagNodePairTraversal = + new NewTagNodePairTraversal(traversal.map { node => + NewTagNodePair() + .tag(NewTag().name(tagName).value(tagValue)) + .node(node) + }) - @Doc(info = "Tags attached to this node") - def tagList: List[List[Tag]] = - traversal.map { taggedNode => - taggedNode.tag.l - }.l + @Doc(info = "Tags attached to this node") + def tagList: List[List[Tag]] = + traversal.map { taggedNode => + taggedNode.tag.l + }.l - @Doc(info = "Tags attached to this node") - def tag: Iterator[Tag] = - traversal.flatMap { node => - node.tag - } + @Doc(info = "Tags attached to this node") + def tag: Iterator[Tag] = + traversal.flatMap { node => + node.tag + } end NodeSteps diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeTypeStarters.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeTypeStarters.scala index 3bf04173..f4392f1b 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeTypeStarters.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/NodeTypeStarters.scala @@ -13,306 +13,306 @@ import scala.jdk.CollectionConverters.IteratorHasAsScala @help.TraversalSource class NodeTypeStarters(cpg: Cpg) extends TraversalSource(cpg.graph): - /** Traverse to all nodes. - */ - @Doc(info = "All nodes of the graph") - override def all: Traversal[StoredNode] = - cpg.graph.nodes.asScala.cast[StoredNode] - - /** Traverse to all annotations - */ - def annotation: Traversal[Annotation] = - InitialTraversal.from[Annotation](cpg.graph, NodeTypes.ANNOTATION) - - /** Traverse to all arguments passed to methods - */ - @Doc(info = "All arguments (actual parameters)") - def argument: Traversal[Expression] = - call.argument - - /** Shorthand for `cpg.argument.code(code)` - */ - def argument(code: String): Traversal[Expression] = - argument.code(code) - - @Doc(info = "All breaks (`ControlStructure` nodes)") - def break: Traversal[ControlStructure] = - controlStructure.isBreak - - /** Traverse to all call sites - */ - @Doc(info = "All call sites") - def call: Traversal[Call] = - InitialTraversal.from[Call](cpg.graph, NodeTypes.CALL) - - /** Shorthand for `cpg.call.name(name)` - */ - def call(name: String): Traversal[Call] = - call.name(name) - - /** Traverse to all comments in source-based CPGs. - */ - @Doc(info = "All comments in source-based CPGs") - def comment: Traversal[Comment] = - InitialTraversal.from[Comment](cpg.graph, NodeTypes.COMMENT) - - /** Shorthand for `cpg.comment.code(code)` - */ - def comment(code: String): Traversal[Comment] = - comment.has(Properties.CODE -> code) - - /** Traverse to all config files - */ - @Doc(info = "All config files") - def configFile: Traversal[ConfigFile] = - InitialTraversal.from[ConfigFile](cpg.graph, NodeTypes.CONFIG_FILE) - - /** Shorthand for `cpg.configFile.name(name)` - */ - def configFile(name: String): Traversal[ConfigFile] = - configFile.name(name) - - /** Traverse to all dependencies - */ - @Doc(info = "All dependencies") - def dependency: Traversal[Dependency] = - InitialTraversal.from[Dependency](cpg.graph, NodeTypes.DEPENDENCY) - - /** Shorthand for `cpg.dependency.name(name)` - */ - def dependency(name: String): Traversal[Dependency] = - dependency.name(name) - - @Doc(info = "All control structures (source-based frontends)") - def controlStructure: Traversal[ControlStructure] = - InitialTraversal.from[ControlStructure](cpg.graph, NodeTypes.CONTROL_STRUCTURE) - - @Doc(info = "All continues (`ControlStructure` nodes)") - def continue: Traversal[ControlStructure] = - controlStructure.isContinue - - @Doc(info = "All do blocks (`ControlStructure` nodes)") - def doBlock: Traversal[ControlStructure] = - controlStructure.isDo - - @Doc(info = "All else blocks (`ControlStructure` nodes)") - def elseBlock: Traversal[ControlStructure] = - controlStructure.isElse - - @Doc(info = "All throws (`ControlStructure` nodes)") - def throws: Traversal[ControlStructure] = - controlStructure.isThrow - - /** Traverse to all source files - */ - @Doc(info = "All source files") - def file: Traversal[File] = - InitialTraversal.from[File](cpg.graph, NodeTypes.FILE) - - /** Shorthand for `cpg.file.name(name)` - */ - def file(name: String): Traversal[File] = - file.name(name) - - @Doc(info = "All for blocks (`ControlStructure` nodes)") - def forBlock: Traversal[ControlStructure] = - controlStructure.isFor - - @Doc(info = "All gotos (`ControlStructure` nodes)") - def goto: Traversal[ControlStructure] = - controlStructure.isGoto - - /** Traverse to all identifiers, e.g., occurrences of local variables or class members in method - * bodies. - */ - @Doc(info = "All identifier usages") - def identifier: Traversal[Identifier] = - InitialTraversal.from[Identifier](cpg.graph, NodeTypes.IDENTIFIER) - - /** Shorthand for `cpg.identifier.name(name)` - */ - def identifier(name: String): Traversal[Identifier] = - identifier.name(name) - - @Doc(info = "All if blocks (`ControlStructure` nodes)") - def ifBlock: Traversal[ControlStructure] = - controlStructure.isIf - - /** Traverse to all jump targets - */ - @Doc(info = "All jump targets, i.e., labels") - def jumpTarget: Traversal[JumpTarget] = - InitialTraversal.from[JumpTarget](cpg.graph, NodeTypes.JUMP_TARGET) - - /** Traverse to all local variable declarations - */ - @Doc(info = "All local variables") - def local: Traversal[Local] = - InitialTraversal.from[Local](cpg.graph, NodeTypes.LOCAL) - - /** Shorthand for `cpg.local.name` - */ - def local(name: String): Traversal[Local] = - local.name(name) - - /** Traverse to all literals (constant strings and numbers provided directly in the code). - */ - @Doc(info = "All literals, e.g., numbers or strings") - def literal: Traversal[Literal] = - InitialTraversal.from[Literal](cpg.graph, NodeTypes.LITERAL) - - /** Shorthand for `cpg.literal.code(code)` - */ - def literal(code: String): Traversal[Literal] = - literal.code(code) - - /** Traverse to all methods - */ - @Doc(info = "All methods") - def method: Traversal[Method] = - InitialTraversal.from[Method](cpg.graph, NodeTypes.METHOD) - - /** Shorthand for `cpg.method.name(name)` - */ - @Doc(info = "All methods with a name that matches the given pattern") - def method(namePattern: String): Traversal[Method] = - method.name(namePattern) - - /** Traverse to all formal return parameters - */ - @Doc(info = "All formal return parameters") - def methodReturn: Traversal[MethodReturn] = - InitialTraversal.from[MethodReturn](cpg.graph, NodeTypes.METHOD_RETURN) - - /** Traverse to all class members - */ - @Doc(info = "All members of complex types (e.g., classes/structures)") - def member: Traversal[Member] = - InitialTraversal.from[Member](cpg.graph, NodeTypes.MEMBER) - - /** Shorthand for `cpg.member.name(name)` - */ - def member(name: String): Traversal[Member] = - member.name(name) - - /** Traverse to all meta data entries - */ - @Doc(info = "Meta data blocks for graph") - def metaData: Traversal[MetaData] = - InitialTraversal.from[MetaData](cpg.graph, NodeTypes.META_DATA) - - /** Traverse to all method references - */ - @Doc(info = "All method references") - def methodRef: Traversal[MethodRef] = - InitialTraversal.from[MethodRef](cpg.graph, NodeTypes.METHOD_REF) - - /** Shorthand for `cpg.methodRef.filter(_.referencedMethod.name(name))` - */ - def methodRef(name: String): Traversal[MethodRef] = - methodRef.where(_.referencedMethod.name(name)) - - /** Traverse to all namespaces, e.g., packages in Java. - */ - @Doc(info = "All namespaces") - def namespace: Traversal[Namespace] = - InitialTraversal.from[Namespace](cpg.graph, NodeTypes.NAMESPACE) - - /** Shorthand for `cpg.namespace.name(name)` - */ - def namespace(name: String): Traversal[Namespace] = - namespace.name(name) - - /** Traverse to all namespace blocks, e.g., packages in Java. - */ - def namespaceBlock: Traversal[NamespaceBlock] = - InitialTraversal.from[NamespaceBlock](cpg.graph, NodeTypes.NAMESPACE_BLOCK) - - /** Shorthand for `cpg.namespaceBlock.name(name)` - */ - def namespaceBlock(name: String): Traversal[NamespaceBlock] = - namespaceBlock.name(name) - - /** Traverse to all input parameters - */ - @Doc(info = "All parameters") - def parameter: Traversal[MethodParameterIn] = - InitialTraversal.from[MethodParameterIn](cpg.graph, NodeTypes.METHOD_PARAMETER_IN) - - /** Shorthand for `cpg.parameter.name(name)` - */ - def parameter(name: String): Traversal[MethodParameterIn] = - parameter.name(name) - - /** Traverse to all return expressions - */ - @Doc(info = "All actual return parameters") - def ret: Traversal[Return] = - InitialTraversal.from[Return](cpg.graph, NodeTypes.RETURN) - - /** Shorthand for `returns.code(code)` - */ - def ret(code: String): Traversal[Return] = - ret.code(code) - - @Doc(info = "All imports") - def imports: Traversal[Import] = - InitialTraversal.from[Import](cpg.graph, NodeTypes.IMPORT) - - @Doc(info = "All switch blocks (`ControlStructure` nodes)") - def switchBlock: Traversal[ControlStructure] = - controlStructure.isSwitch - - @Doc(info = "All try blocks (`ControlStructure` nodes)") - def tryBlock: Traversal[ControlStructure] = - controlStructure.isTry - - /** Traverse to all types, e.g., Set - */ - @Doc(info = "All used types") - def typ: Traversal[Type] = - InitialTraversal.from[Type](cpg.graph, NodeTypes.TYPE) - - /** Shorthand for `cpg.typ.name(name)` - */ - @Doc(info = "All used types with given name") - def typ(name: String): Traversal[Type] = - typ.name(name) - - /** Traverse to all declarations, e.g., Set - */ - @Doc(info = "All declarations of types") - def typeDecl: Traversal[TypeDecl] = - InitialTraversal.from[TypeDecl](cpg.graph, NodeTypes.TYPE_DECL) - - /** Shorthand for cpg.typeDecl.name(name) - */ - def typeDecl(name: String): Traversal[TypeDecl] = - typeDecl.name(name) - - /** Traverse to all tags - */ - @Doc(info = "All tags") - def tag: Traversal[Tag] = - InitialTraversal.from[Tag](cpg.graph, NodeTypes.TAG) - - @Doc(info = "All tags with given name") - def tag(name: String): Traversal[Tag] = - tag.name(name) - - /** Traverse to all template DOM nodes - */ - @Doc(info = "All template DOM nodes") - def templateDom: Traversal[TemplateDom] = - InitialTraversal.from[TemplateDom](cpg.graph, NodeTypes.TEMPLATE_DOM) - - /** Traverse to all type references - */ - @Doc(info = "All type references") - def typeRef: Traversal[TypeRef] = - InitialTraversal.from[TypeRef](cpg.graph, NodeTypes.TYPE_REF) - - @Doc(info = "All while blocks (`ControlStructure` nodes)") - def whileBlock: Traversal[ControlStructure] = - controlStructure.isWhile + /** Traverse to all nodes. + */ + @Doc(info = "All nodes of the graph") + override def all: Traversal[StoredNode] = + cpg.graph.nodes.asScala.cast[StoredNode] + + /** Traverse to all annotations + */ + def annotation: Traversal[Annotation] = + InitialTraversal.from[Annotation](cpg.graph, NodeTypes.ANNOTATION) + + /** Traverse to all arguments passed to methods + */ + @Doc(info = "All arguments (actual parameters)") + def argument: Traversal[Expression] = + call.argument + + /** Shorthand for `cpg.argument.code(code)` + */ + def argument(code: String): Traversal[Expression] = + argument.code(code) + + @Doc(info = "All breaks (`ControlStructure` nodes)") + def break: Traversal[ControlStructure] = + controlStructure.isBreak + + /** Traverse to all call sites + */ + @Doc(info = "All call sites") + def call: Traversal[Call] = + InitialTraversal.from[Call](cpg.graph, NodeTypes.CALL) + + /** Shorthand for `cpg.call.name(name)` + */ + def call(name: String): Traversal[Call] = + call.name(name) + + /** Traverse to all comments in source-based CPGs. + */ + @Doc(info = "All comments in source-based CPGs") + def comment: Traversal[Comment] = + InitialTraversal.from[Comment](cpg.graph, NodeTypes.COMMENT) + + /** Shorthand for `cpg.comment.code(code)` + */ + def comment(code: String): Traversal[Comment] = + comment.has(Properties.CODE -> code) + + /** Traverse to all config files + */ + @Doc(info = "All config files") + def configFile: Traversal[ConfigFile] = + InitialTraversal.from[ConfigFile](cpg.graph, NodeTypes.CONFIG_FILE) + + /** Shorthand for `cpg.configFile.name(name)` + */ + def configFile(name: String): Traversal[ConfigFile] = + configFile.name(name) + + /** Traverse to all dependencies + */ + @Doc(info = "All dependencies") + def dependency: Traversal[Dependency] = + InitialTraversal.from[Dependency](cpg.graph, NodeTypes.DEPENDENCY) + + /** Shorthand for `cpg.dependency.name(name)` + */ + def dependency(name: String): Traversal[Dependency] = + dependency.name(name) + + @Doc(info = "All control structures (source-based frontends)") + def controlStructure: Traversal[ControlStructure] = + InitialTraversal.from[ControlStructure](cpg.graph, NodeTypes.CONTROL_STRUCTURE) + + @Doc(info = "All continues (`ControlStructure` nodes)") + def continue: Traversal[ControlStructure] = + controlStructure.isContinue + + @Doc(info = "All do blocks (`ControlStructure` nodes)") + def doBlock: Traversal[ControlStructure] = + controlStructure.isDo + + @Doc(info = "All else blocks (`ControlStructure` nodes)") + def elseBlock: Traversal[ControlStructure] = + controlStructure.isElse + + @Doc(info = "All throws (`ControlStructure` nodes)") + def throws: Traversal[ControlStructure] = + controlStructure.isThrow + + /** Traverse to all source files + */ + @Doc(info = "All source files") + def file: Traversal[File] = + InitialTraversal.from[File](cpg.graph, NodeTypes.FILE) + + /** Shorthand for `cpg.file.name(name)` + */ + def file(name: String): Traversal[File] = + file.name(name) + + @Doc(info = "All for blocks (`ControlStructure` nodes)") + def forBlock: Traversal[ControlStructure] = + controlStructure.isFor + + @Doc(info = "All gotos (`ControlStructure` nodes)") + def goto: Traversal[ControlStructure] = + controlStructure.isGoto + + /** Traverse to all identifiers, e.g., occurrences of local variables or class members in method + * bodies. + */ + @Doc(info = "All identifier usages") + def identifier: Traversal[Identifier] = + InitialTraversal.from[Identifier](cpg.graph, NodeTypes.IDENTIFIER) + + /** Shorthand for `cpg.identifier.name(name)` + */ + def identifier(name: String): Traversal[Identifier] = + identifier.name(name) + + @Doc(info = "All if blocks (`ControlStructure` nodes)") + def ifBlock: Traversal[ControlStructure] = + controlStructure.isIf + + /** Traverse to all jump targets + */ + @Doc(info = "All jump targets, i.e., labels") + def jumpTarget: Traversal[JumpTarget] = + InitialTraversal.from[JumpTarget](cpg.graph, NodeTypes.JUMP_TARGET) + + /** Traverse to all local variable declarations + */ + @Doc(info = "All local variables") + def local: Traversal[Local] = + InitialTraversal.from[Local](cpg.graph, NodeTypes.LOCAL) + + /** Shorthand for `cpg.local.name` + */ + def local(name: String): Traversal[Local] = + local.name(name) + + /** Traverse to all literals (constant strings and numbers provided directly in the code). + */ + @Doc(info = "All literals, e.g., numbers or strings") + def literal: Traversal[Literal] = + InitialTraversal.from[Literal](cpg.graph, NodeTypes.LITERAL) + + /** Shorthand for `cpg.literal.code(code)` + */ + def literal(code: String): Traversal[Literal] = + literal.code(code) + + /** Traverse to all methods + */ + @Doc(info = "All methods") + def method: Traversal[Method] = + InitialTraversal.from[Method](cpg.graph, NodeTypes.METHOD) + + /** Shorthand for `cpg.method.name(name)` + */ + @Doc(info = "All methods with a name that matches the given pattern") + def method(namePattern: String): Traversal[Method] = + method.name(namePattern) + + /** Traverse to all formal return parameters + */ + @Doc(info = "All formal return parameters") + def methodReturn: Traversal[MethodReturn] = + InitialTraversal.from[MethodReturn](cpg.graph, NodeTypes.METHOD_RETURN) + + /** Traverse to all class members + */ + @Doc(info = "All members of complex types (e.g., classes/structures)") + def member: Traversal[Member] = + InitialTraversal.from[Member](cpg.graph, NodeTypes.MEMBER) + + /** Shorthand for `cpg.member.name(name)` + */ + def member(name: String): Traversal[Member] = + member.name(name) + + /** Traverse to all meta data entries + */ + @Doc(info = "Meta data blocks for graph") + def metaData: Traversal[MetaData] = + InitialTraversal.from[MetaData](cpg.graph, NodeTypes.META_DATA) + + /** Traverse to all method references + */ + @Doc(info = "All method references") + def methodRef: Traversal[MethodRef] = + InitialTraversal.from[MethodRef](cpg.graph, NodeTypes.METHOD_REF) + + /** Shorthand for `cpg.methodRef.filter(_.referencedMethod.name(name))` + */ + def methodRef(name: String): Traversal[MethodRef] = + methodRef.where(_.referencedMethod.name(name)) + + /** Traverse to all namespaces, e.g., packages in Java. + */ + @Doc(info = "All namespaces") + def namespace: Traversal[Namespace] = + InitialTraversal.from[Namespace](cpg.graph, NodeTypes.NAMESPACE) + + /** Shorthand for `cpg.namespace.name(name)` + */ + def namespace(name: String): Traversal[Namespace] = + namespace.name(name) + + /** Traverse to all namespace blocks, e.g., packages in Java. + */ + def namespaceBlock: Traversal[NamespaceBlock] = + InitialTraversal.from[NamespaceBlock](cpg.graph, NodeTypes.NAMESPACE_BLOCK) + + /** Shorthand for `cpg.namespaceBlock.name(name)` + */ + def namespaceBlock(name: String): Traversal[NamespaceBlock] = + namespaceBlock.name(name) + + /** Traverse to all input parameters + */ + @Doc(info = "All parameters") + def parameter: Traversal[MethodParameterIn] = + InitialTraversal.from[MethodParameterIn](cpg.graph, NodeTypes.METHOD_PARAMETER_IN) + + /** Shorthand for `cpg.parameter.name(name)` + */ + def parameter(name: String): Traversal[MethodParameterIn] = + parameter.name(name) + + /** Traverse to all return expressions + */ + @Doc(info = "All actual return parameters") + def ret: Traversal[Return] = + InitialTraversal.from[Return](cpg.graph, NodeTypes.RETURN) + + /** Shorthand for `returns.code(code)` + */ + def ret(code: String): Traversal[Return] = + ret.code(code) + + @Doc(info = "All imports") + def imports: Traversal[Import] = + InitialTraversal.from[Import](cpg.graph, NodeTypes.IMPORT) + + @Doc(info = "All switch blocks (`ControlStructure` nodes)") + def switchBlock: Traversal[ControlStructure] = + controlStructure.isSwitch + + @Doc(info = "All try blocks (`ControlStructure` nodes)") + def tryBlock: Traversal[ControlStructure] = + controlStructure.isTry + + /** Traverse to all types, e.g., Set + */ + @Doc(info = "All used types") + def typ: Traversal[Type] = + InitialTraversal.from[Type](cpg.graph, NodeTypes.TYPE) + + /** Shorthand for `cpg.typ.name(name)` + */ + @Doc(info = "All used types with given name") + def typ(name: String): Traversal[Type] = + typ.name(name) + + /** Traverse to all declarations, e.g., Set + */ + @Doc(info = "All declarations of types") + def typeDecl: Traversal[TypeDecl] = + InitialTraversal.from[TypeDecl](cpg.graph, NodeTypes.TYPE_DECL) + + /** Shorthand for cpg.typeDecl.name(name) + */ + def typeDecl(name: String): Traversal[TypeDecl] = + typeDecl.name(name) + + /** Traverse to all tags + */ + @Doc(info = "All tags") + def tag: Traversal[Tag] = + InitialTraversal.from[Tag](cpg.graph, NodeTypes.TAG) + + @Doc(info = "All tags with given name") + def tag(name: String): Traversal[Tag] = + tag.name(name) + + /** Traverse to all template DOM nodes + */ + @Doc(info = "All template DOM nodes") + def templateDom: Traversal[TemplateDom] = + InitialTraversal.from[TemplateDom](cpg.graph, NodeTypes.TEMPLATE_DOM) + + /** Traverse to all type references + */ + @Doc(info = "All type references") + def typeRef: Traversal[TypeRef] = + InitialTraversal.from[TypeRef](cpg.graph, NodeTypes.TYPE_REF) + + @Doc(info = "All while blocks (`ControlStructure` nodes)") + def whileBlock: Traversal[ControlStructure] = + controlStructure.isWhile end NodeTypeStarters diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/Show.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/Show.scala index 30a2fec7..7195b7dc 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/Show.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/Show.scala @@ -8,30 +8,30 @@ import scala.jdk.CollectionConverters.* /** Typeclass for (pretty) printing an object */ trait Show[A]: - def apply(a: A): String + def apply(a: A): String object Show: - def default[A]: Show[A] = Default.asInstanceOf[Show[A]] - - private val Default = new Show[Any]: - override def apply(a: Any): String = a match - case node: NewNode => - val label = node.label - val properties = propsToString(node.properties.toList) - s"($label): $properties" - - case node: Node => - val label = node.label - val id = node.id().toString - val properties = propsToString(node.propertiesMap.asScala.toList) - s"($label,$id): $properties" - - case other => other.toString - - private def propsToString(keyValues: List[(String, Any)]): String = - keyValues - .filter(_._2.toString.nonEmpty) - .sortBy(_._1) - .map { case (key, value) => s"$key: $value" } - .mkString(", ") + def default[A]: Show[A] = Default.asInstanceOf[Show[A]] + + private val Default = new Show[Any]: + override def apply(a: Any): String = a match + case node: NewNode => + val label = node.label + val properties = propsToString(node.properties.toList) + s"($label): $properties" + + case node: Node => + val label = node.label + val id = node.id().toString + val properties = propsToString(node.propertiesMap.asScala.toList) + s"($label,$id): $properties" + + case other => other.toString + + private def propsToString(keyValues: List[(String, Any)]): String = + keyValues + .filter(_._2.toString.nonEmpty) + .sortBy(_._1) + .map { case (key, value) => s"$key: $value" } + .mkString(", ") end Show diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/Steps.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/Steps.scala index fa7f5c17..c40ef4df 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/Steps.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/Steps.scala @@ -23,93 +23,93 @@ import java.nio.file.Files */ class Steps[A](val traversal: Iterator[A]) extends AnyVal: - /** Execute the traversal and convert it to a mutable buffer - */ - def toBuffer(): mutable.Buffer[A] = traversal.to(mutable.Buffer) - - /** Shorthand for `toBuffer` - */ - def b: mutable.Buffer[A] = toBuffer() - - /** Alias for `toList` - * @deprecated - */ - def exec(): List[A] = traversal.toList - - /** Execute the travel and convert it to a Java stream. - */ - def toStream(): LazyList[A] = traversal.to(LazyList) - - /** Alias for `toStream` - */ - def s: LazyList[A] = toStream() - - /** Execute the traversal and convert it into a Java list (as opposed to the Scala list obtained - * via `toList`) - */ - def jl: JList[A] = b.asJava - - /** Execute this traversal and pretty print the results. This may mean that not all properties - * of the node are displayed or that some properties have undergone transformations to improve - * display. A good example is flow pretty-printing. This is the only three of the methods which - * we may modify on a per-node-type basis, typically via implicits of type Show[NodeType]. - */ - @Doc(info = "execute this traversal and pretty print the results") - def p(implicit show: Show[A] = Show.default): List[String] = - traversal.toList.map(show.apply) - - @Doc(info = "execute this traversal and print tabular result") - def t(implicit show: Show[A] = Show.default): Unit = - traversal.toList.map(show.apply) - - @Doc(info = "execute this traversal and show the pretty-printed results in `less`") - // uses scala-repl-pp's `#|^` operator which let's `less` inherit stdin and stdout - def browse: Unit = - given Colors = Colors.Default - traversal #|^ "less" - - /** Execute traversal and convert the result to json. `toJson` (export) contains the exact same - * information as `toList`, only in json format. Typically, the user will call this method upon - * inspection of the results of `toList` in order to export the data for processing with other - * tools. - */ - @Doc(info = "execute traversal and convert the result to json") - def toJson: String = toJson(pretty = false) - - /** Execute traversal and convert the result to pretty json. */ - @Doc(info = "execute traversal and convert the result to pretty json") - def toJsonPretty: String = toJson(pretty = true) - - protected def toJson(pretty: Boolean): String = - implicit val formats: Formats = org.json4s.DefaultFormats + Steps.nodeSerializer - - val results = traversal.toList - if pretty then writePretty(results) - else write(results) - - private def pyJson = py.module("json") - @Doc(info = "execute traversal and convert the result to python object") - def toPy: me.shadaj.scalapy.py.Dynamic = pyJson.loads(toJson(false)) - - def pyg = - val tmpDir = Files.createTempDirectory("pyg-gml-export").toFile.getAbsolutePath - traversal match - case methods: Iterator[Method] => - val exportResult = methods.gml(tmpDir) - exportResult.files.map(Torch.to_pyg) + /** Execute the traversal and convert it to a mutable buffer + */ + def toBuffer(): mutable.Buffer[A] = traversal.to(mutable.Buffer) + + /** Shorthand for `toBuffer` + */ + def b: mutable.Buffer[A] = toBuffer() + + /** Alias for `toList` + * @deprecated + */ + def exec(): List[A] = traversal.toList + + /** Execute the travel and convert it to a Java stream. + */ + def toStream(): LazyList[A] = traversal.to(LazyList) + + /** Alias for `toStream` + */ + def s: LazyList[A] = toStream() + + /** Execute the traversal and convert it into a Java list (as opposed to the Scala list obtained + * via `toList`) + */ + def jl: JList[A] = b.asJava + + /** Execute this traversal and pretty print the results. This may mean that not all properties of + * the node are displayed or that some properties have undergone transformations to improve + * display. A good example is flow pretty-printing. This is the only three of the methods which + * we may modify on a per-node-type basis, typically via implicits of type Show[NodeType]. + */ + @Doc(info = "execute this traversal and pretty print the results") + def p(implicit show: Show[A] = Show.default): List[String] = + traversal.toList.map(show.apply) + + @Doc(info = "execute this traversal and print tabular result") + def t(implicit show: Show[A] = Show.default): Unit = + traversal.toList.map(show.apply) + + @Doc(info = "execute this traversal and show the pretty-printed results in `less`") + // uses scala-repl-pp's `#|^` operator which let's `less` inherit stdin and stdout + def browse: Unit = + given Colors = Colors.Default + traversal #|^ "less" + + /** Execute traversal and convert the result to json. `toJson` (export) contains the exact same + * information as `toList`, only in json format. Typically, the user will call this method upon + * inspection of the results of `toList` in order to export the data for processing with other + * tools. + */ + @Doc(info = "execute traversal and convert the result to json") + def toJson: String = toJson(pretty = false) + + /** Execute traversal and convert the result to pretty json. */ + @Doc(info = "execute traversal and convert the result to pretty json") + def toJsonPretty: String = toJson(pretty = true) + + protected def toJson(pretty: Boolean): String = + implicit val formats: Formats = org.json4s.DefaultFormats + Steps.nodeSerializer + + val results = traversal.toList + if pretty then writePretty(results) + else write(results) + + private def pyJson = py.module("json") + @Doc(info = "execute traversal and convert the result to python object") + def toPy: me.shadaj.scalapy.py.Dynamic = pyJson.loads(toJson(false)) + + def pyg = + val tmpDir = Files.createTempDirectory("pyg-gml-export").toFile.getAbsolutePath + traversal match + case methods: Iterator[Method] => + val exportResult = methods.gml(tmpDir) + exportResult.files.map(Torch.to_pyg) end Steps object Steps: - private lazy val nodeSerializer = new CustomSerializer[AbstractNode](implicit format => - ( - { case _ => ??? }, - { case node: AbstractNode with Product => - val elementMap = (0 until node.productArity).map { i => - val label = node.productElementName(i) - val element = node.productElement(i) - label -> element - }.toMap + ("_label" -> node.label) - Extraction.decompose(elementMap) - } - ) - ) + private lazy val nodeSerializer = new CustomSerializer[AbstractNode](implicit format => + ( + { case _ => ??? }, + { case node: AbstractNode with Product => + val elementMap = (0 until node.productArity).map { i => + val label = node.productElementName(i) + val element = node.productElement(i) + label -> element + }.toMap + ("_label" -> node.label) + Extraction.decompose(elementMap) + } + ) + ) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/TagTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/TagTraversal.scala index b4e088db..7a80dae6 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/TagTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/TagTraversal.scala @@ -6,16 +6,16 @@ import scala.reflect.ClassTag class TagTraversal(val traversal: Iterator[Tag]) extends AnyVal: - def member: Iterator[Member] = tagged[Member] - def method: Iterator[Method] = tagged[Method] - def methodReturn: Iterator[MethodReturn] = tagged[MethodReturn] - def parameter: Iterator[MethodParameterIn] = tagged[MethodParameterIn] - def parameterOut: Iterator[MethodParameterOut] = tagged[MethodParameterOut] - def call: Iterator[Call] = tagged[Call] - def identifier: Iterator[Identifier] = tagged[Identifier] - def literal: Iterator[Literal] = tagged[Literal] - def local: Iterator[Local] = tagged[Local] - def file: Iterator[File] = tagged[File] + def member: Iterator[Member] = tagged[Member] + def method: Iterator[Method] = tagged[Method] + def methodReturn: Iterator[MethodReturn] = tagged[MethodReturn] + def parameter: Iterator[MethodParameterIn] = tagged[MethodParameterIn] + def parameterOut: Iterator[MethodParameterOut] = tagged[MethodParameterOut] + def call: Iterator[Call] = tagged[Call] + def identifier: Iterator[Identifier] = tagged[Identifier] + def literal: Iterator[Literal] = tagged[Literal] + def local: Iterator[Local] = tagged[Local] + def file: Iterator[File] = tagged[File] - private def tagged[A <: StoredNode: ClassTag]: Iterator[A] = - traversal._taggedByIn.collectAll[A].sortBy(_.id).iterator + private def tagged[A <: StoredNode: ClassTag]: Iterator[A] = + traversal._taggedByIn.collectAll[A].sortBy(_.id).iterator diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/ConfigFileTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/ConfigFileTraversal.scala index fe8f3722..bad5f038 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/ConfigFileTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/ConfigFileTraversal.scala @@ -4,98 +4,98 @@ import io.shiftleft.semanticcpg.utils.SecureXmlParsing import io.shiftleft.codepropertygraph.generated.nodes class ConfigFileTraversal(val traversal: Iterator[nodes.ConfigFile]) extends AnyVal: - def usesCleartextTraffic = - traversal - .filter(_.name.endsWith(Constants.androidManifestXml)) - .map(_.content) - .flatMap(SecureXmlParsing.parseXml) - .filter(_.label == "manifest") - .flatMap(_.child) - .filter(_.label == "application") - .flatMap { applicationNode => - val activityName = - applicationNode.attribute(Constants.androidUri, "usesCleartextTraffic") - activityName.map(_.toString == "true") - } + def usesCleartextTraffic = + traversal + .filter(_.name.endsWith(Constants.androidManifestXml)) + .map(_.content) + .flatMap(SecureXmlParsing.parseXml) + .filter(_.label == "manifest") + .flatMap(_.child) + .filter(_.label == "application") + .flatMap { applicationNode => + val activityName = + applicationNode.attribute(Constants.androidUri, "usesCleartextTraffic") + activityName.map(_.toString == "true") + } - def hasReadExternalStoragePermission = - traversal - .filter(_.name.endsWith(Constants.androidManifestXml)) - .map(_.content) - .flatMap(SecureXmlParsing.parseXml) - .filter(_.label == "manifest") - .flatMap(_.child) - .filter(_.label == "uses-permission") - .flatMap { applicationNode => - val activityName = applicationNode.attribute(Constants.androidUri, "name") - activityName match - case Some(n) if n.toString == "android.permission.READ_EXTERNAL_STORAGE" => - Some(true) - case _ => None - } + def hasReadExternalStoragePermission = + traversal + .filter(_.name.endsWith(Constants.androidManifestXml)) + .map(_.content) + .flatMap(SecureXmlParsing.parseXml) + .filter(_.label == "manifest") + .flatMap(_.child) + .filter(_.label == "uses-permission") + .flatMap { applicationNode => + val activityName = applicationNode.attribute(Constants.androidUri, "name") + activityName match + case Some(n) if n.toString == "android.permission.READ_EXTERNAL_STORAGE" => + Some(true) + case _ => None + } - def exportedAndroidActivityNames = - traversal - .filter(_.name.endsWith(Constants.androidManifestXml)) - .map(_.content) - .flatMap(SecureXmlParsing.parseXml) - .filter(_.label == "manifest") - .flatMap(_.child) - .filter(_.label == "application") - .flatMap(_.child) - .filter(_.label == "activity") - .flatMap { activityNode => - /* + def exportedAndroidActivityNames = + traversal + .filter(_.name.endsWith(Constants.androidManifestXml)) + .map(_.content) + .flatMap(SecureXmlParsing.parseXml) + .filter(_.label == "manifest") + .flatMap(_.child) + .filter(_.label == "application") + .flatMap(_.child) + .filter(_.label == "activity") + .flatMap { activityNode => + /* from: https://developer.android.com/guide/components/intents-filters Note: To receive implicit intents, you must include the CATEGORY_DEFAULT category in the intent filter. The methods startActivity() and startActivityForResult() treat all intents as if they declared the CATEGORY_DEFAULT category. If you do not declare this category in your intent filter, no implicit intents will resolve to your activity. - */ - val hasIntentFilterWithDefaultCategory = - activityNode - .flatMap(_.child) - .filter(_.label == "intent-filter") - .flatMap(_.child) - .filter(_.label == "category") - .exists { node => - val categoryName = node.attribute(Constants.androidUri, "name") - categoryName match - case Some(n) => n.toString == "android.intent.category.DEFAULT" - case None => false - } - if hasIntentFilterWithDefaultCategory then - val activityName = activityNode.attribute(Constants.androidUri, "name") - activityName match - case Some(n) => Some(n.toString) - case None => None - else None - } + */ + val hasIntentFilterWithDefaultCategory = + activityNode + .flatMap(_.child) + .filter(_.label == "intent-filter") + .flatMap(_.child) + .filter(_.label == "category") + .exists { node => + val categoryName = node.attribute(Constants.androidUri, "name") + categoryName match + case Some(n) => n.toString == "android.intent.category.DEFAULT" + case None => false + } + if hasIntentFilterWithDefaultCategory then + val activityName = activityNode.attribute(Constants.androidUri, "name") + activityName match + case Some(n) => Some(n.toString) + case None => None + else None + } - def exportedBroadcastReceiverNames = - traversal - .filter(_.name.endsWith(Constants.androidManifestXml)) - .map(_.content) - .flatMap(SecureXmlParsing.parseXml) - .filter(_.label == "manifest") - .flatMap(_.child) - .filter(_.label == "application") - .flatMap(_.child) - .filter(_.label == "receiver") - .flatMap { receiverNode => - val hasIntentFilter = - receiverNode.flatMap(_.child).filter(_.label == "intent-filter").nonEmpty - if hasIntentFilter then - val isExported = receiverNode.attribute(Constants.androidUri, "exported") - isExported match - case Some(n) if n.toString == "true" => Some(receiverNode) - case _ => None - else None - } - .flatMap { node => - val name = node.attribute(Constants.androidUri, "name") - name match - case Some(n) => Some(n.toString().stripPrefix(".")) - case _ => None - } + def exportedBroadcastReceiverNames = + traversal + .filter(_.name.endsWith(Constants.androidManifestXml)) + .map(_.content) + .flatMap(SecureXmlParsing.parseXml) + .filter(_.label == "manifest") + .flatMap(_.child) + .filter(_.label == "application") + .flatMap(_.child) + .filter(_.label == "receiver") + .flatMap { receiverNode => + val hasIntentFilter = + receiverNode.flatMap(_.child).filter(_.label == "intent-filter").nonEmpty + if hasIntentFilter then + val isExported = receiverNode.attribute(Constants.androidUri, "exported") + isExported match + case Some(n) if n.toString == "true" => Some(receiverNode) + case _ => None + else None + } + .flatMap { node => + val name = node.attribute(Constants.androidUri, "name") + name match + case Some(n) => Some(n.toString().stripPrefix(".")) + case _ => None + } end ConfigFileTraversal diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/Constants.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/Constants.scala index cef6d9bc..6301cd2b 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/Constants.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/Constants.scala @@ -1,5 +1,5 @@ package io.shiftleft.semanticcpg.language.android object Constants: - val androidUri = "http://schemas.android.com/apk/res/android" - val androidManifestXml = "AndroidManifest.xml" + val androidUri = "http://schemas.android.com/apk/res/android" + val androidManifestXml = "AndroidManifest.xml" diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/LocalTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/LocalTraversal.scala index 7cbf0fe4..21f66fa5 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/LocalTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/LocalTraversal.scala @@ -4,23 +4,23 @@ import io.shiftleft.codepropertygraph.generated.nodes.Local import io.shiftleft.semanticcpg.language.* class LocalTraversal(val traversal: Iterator[Local]) extends AnyVal: - def callsEnableJS = - traversal - .where( - _.referencingIdentifiers.inCall - .nameExact("getSettings") - .where( - _.inCall - .nameExact("setJavaScriptEnabled") - .argument - .isLiteral - .codeExact("true") - ) - ) + def callsEnableJS = + traversal + .where( + _.referencingIdentifiers.inCall + .nameExact("getSettings") + .where( + _.inCall + .nameExact("setJavaScriptEnabled") + .argument + .isLiteral + .codeExact("true") + ) + ) - def loadUrlCalls = - traversal.referencingIdentifiers.inCall.nameExact("loadUrl") + def loadUrlCalls = + traversal.referencingIdentifiers.inCall.nameExact("loadUrl") - def addJavascriptInterfaceCalls = - traversal.referencingIdentifiers.inCall.nameExact("addJavascriptInterface") + def addJavascriptInterfaceCalls = + traversal.referencingIdentifiers.inCall.nameExact("addJavascriptInterface") end LocalTraversal diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/MethodTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/MethodTraversal.scala index b4c05be0..35edd480 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/MethodTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/MethodTraversal.scala @@ -4,5 +4,5 @@ import io.shiftleft.codepropertygraph.generated.nodes import io.shiftleft.semanticcpg.language.* class MethodTraversal(val traversal: Iterator[nodes.Method]) extends AnyVal: - def exposedToJS = - traversal.where(_.annotation.fullNameExact("android.webkit.JavascriptInterface")) + def exposedToJS = + traversal.where(_.annotation.fullNameExact("android.webkit.JavascriptInterface")) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/NodeTypeStarters.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/NodeTypeStarters.scala index cdb97d8b..1c1bdb9e 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/NodeTypeStarters.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/NodeTypeStarters.scala @@ -5,36 +5,36 @@ import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* class NodeTypeStarters(cpg: Cpg): - def webView: Iterator[Local] = - cpg.local.typeFullNameExact("android.webkit.WebView") - - def appManifest: Iterator[ConfigFile] = - cpg.configFile.filter(_.name.endsWith(Constants.androidManifestXml)) - - def getExternalStorageDir: Iterator[Call] = - cpg.call - .nameExact("getExternalStorageDirectory") - .where(_.argument(0).isIdentifier.typeFullNameExact("android.os.Environment")) - - def dexClassLoader: Iterator[Local] = - cpg.local.typeFullNameExact("dalvik.system.DexClassLoader") - - def broadcastReceivers: Iterator[TypeDecl] = - cpg.method - .nameExact("onReceive") - .where(_.parameter.index(1).typeFullNameExact("android.content.Context")) - .where(_.parameter.index(2).typeFullNameExact("android.content.Intent")) - .typeDecl - - def registerReceiver: Iterator[Call] = - cpg.call - .nameExact("registerReceiver") - .where(_.argument(2).isIdentifier.typeFullNameExact("android.content.IntentFilter")) - - def registeredBroadcastReceivers = - cpg.broadcastReceivers.filter { broadcastReceiver => - cpg.registerReceiver.argument(1).isIdentifier.typeFullName.exists( - _ == broadcastReceiver.fullName - ) - } + def webView: Iterator[Local] = + cpg.local.typeFullNameExact("android.webkit.WebView") + + def appManifest: Iterator[ConfigFile] = + cpg.configFile.filter(_.name.endsWith(Constants.androidManifestXml)) + + def getExternalStorageDir: Iterator[Call] = + cpg.call + .nameExact("getExternalStorageDirectory") + .where(_.argument(0).isIdentifier.typeFullNameExact("android.os.Environment")) + + def dexClassLoader: Iterator[Local] = + cpg.local.typeFullNameExact("dalvik.system.DexClassLoader") + + def broadcastReceivers: Iterator[TypeDecl] = + cpg.method + .nameExact("onReceive") + .where(_.parameter.index(1).typeFullNameExact("android.content.Context")) + .where(_.parameter.index(2).typeFullNameExact("android.content.Intent")) + .typeDecl + + def registerReceiver: Iterator[Call] = + cpg.call + .nameExact("registerReceiver") + .where(_.argument(2).isIdentifier.typeFullNameExact("android.content.IntentFilter")) + + def registeredBroadcastReceivers = + cpg.broadcastReceivers.filter { broadcastReceiver => + cpg.registerReceiver.argument(1).isIdentifier.typeFullName.exists( + _ == broadcastReceiver.fullName + ) + } end NodeTypeStarters diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/package.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/package.scala index 6c0dc3c9..f790127f 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/package.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/android/package.scala @@ -6,24 +6,24 @@ import io.shiftleft.codepropertygraph.generated.nodes.{ConfigFile, Literal, Loca /** Language extensions for android. */ package object android: - implicit def toNodeTypeStartersFlows(cpg: Cpg): NodeTypeStarters = - new NodeTypeStarters(cpg) + implicit def toNodeTypeStartersFlows(cpg: Cpg): NodeTypeStarters = + new NodeTypeStarters(cpg) - implicit def singleToLocalExt[A <: Local](a: A): LocalTraversal = - new LocalTraversal(Iterator.single(a)) + implicit def singleToLocalExt[A <: Local](a: A): LocalTraversal = + new LocalTraversal(Iterator.single(a)) - implicit def iterOnceToLocalExt[A <: Local](a: IterableOnce[A]): LocalTraversal = - new LocalTraversal(a.iterator) + implicit def iterOnceToLocalExt[A <: Local](a: IterableOnce[A]): LocalTraversal = + new LocalTraversal(a.iterator) - implicit def singleToConfigFileExt[A <: ConfigFile](a: A): ConfigFileTraversal = - new ConfigFileTraversal(Iterator.single(a)) + implicit def singleToConfigFileExt[A <: ConfigFile](a: A): ConfigFileTraversal = + new ConfigFileTraversal(Iterator.single(a)) - implicit def iterOnceToConfigFileExt[A <: ConfigFile](a: IterableOnce[A]): ConfigFileTraversal = - new ConfigFileTraversal(a.iterator) + implicit def iterOnceToConfigFileExt[A <: ConfigFile](a: IterableOnce[A]): ConfigFileTraversal = + new ConfigFileTraversal(a.iterator) - implicit def singleToMethodExt[A <: Method](a: A): MethodTraversal = - new MethodTraversal(Iterator.single(a)) + implicit def singleToMethodExt[A <: Method](a: A): MethodTraversal = + new MethodTraversal(Iterator.single(a)) - implicit def iterOnceToMethodExt[A <: Method](a: IterableOnce[A]): MethodTraversal = - new MethodTraversal(a.iterator) + implicit def iterOnceToMethodExt[A <: Method](a: IterableOnce[A]): MethodTraversal = + new MethodTraversal(a.iterator) end android diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/bindingextension/MethodTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/bindingextension/MethodTraversal.scala index 9e222c97..1216c560 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/bindingextension/MethodTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/bindingextension/MethodTraversal.scala @@ -5,12 +5,12 @@ import io.shiftleft.semanticcpg.language.* class MethodTraversal(val traversal: Iterator[Method]) extends AnyVal: - /** Traverse to type decl which have this method bound to it. - */ - def bindingTypeDecl: Iterator[TypeDecl] = - referencingBinding.bindingTypeDecl + /** Traverse to type decl which have this method bound to it. + */ + def bindingTypeDecl: Iterator[TypeDecl] = + referencingBinding.bindingTypeDecl - /** Traverse to bindings which reference to this method. - */ - def referencingBinding: Iterator[Binding] = - traversal.flatMap(_._bindingViaRefIn) + /** Traverse to bindings which reference to this method. + */ + def referencingBinding: Iterator[Binding] = + traversal.flatMap(_._bindingViaRefIn) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/bindingextension/TypeDeclTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/bindingextension/TypeDeclTraversal.scala index be9c73df..9143b384 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/bindingextension/TypeDeclTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/bindingextension/TypeDeclTraversal.scala @@ -5,12 +5,12 @@ import io.shiftleft.semanticcpg.language.* class TypeDeclTraversal(val traversal: Iterator[TypeDecl]) extends AnyVal: - /** Traverse to methods bound to this type decl. - */ - def boundMethod: Iterator[Method] = - methodBinding.boundMethod + /** Traverse to methods bound to this type decl. + */ + def boundMethod: Iterator[Method] = + methodBinding.boundMethod - /** Traverse to the method bindings of this type declaration. - */ - def methodBinding: Iterator[Binding] = - traversal.canonicalType.flatMap(_.bindsOut) + /** Traverse to the method bindings of this type declaration. + */ + def methodBinding: Iterator[Binding] = + traversal.canonicalType.flatMap(_.bindsOut) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/callgraphextension/CallTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/callgraphextension/CallTraversal.scala index e4b18426..2dda6253 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/callgraphextension/CallTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/callgraphextension/CallTraversal.scala @@ -5,12 +5,12 @@ import io.shiftleft.semanticcpg.language.* class CallTraversal(val traversal: Iterator[Call]) extends AnyVal: - @deprecated("Use callee", "") - def calledMethod(implicit callResolver: ICallResolver): Iterator[Method] = callee + @deprecated("Use callee", "") + def calledMethod(implicit callResolver: ICallResolver): Iterator[Method] = callee - /** The callee method */ - def callee(implicit callResolver: ICallResolver): Iterator[Method] = - traversal.flatMap(callResolver.getCalledMethodsAsTraversal) + /** The callee method */ + def callee(implicit callResolver: ICallResolver): Iterator[Method] = + traversal.flatMap(callResolver.getCalledMethodsAsTraversal) - def referencedImports: Iterator[Import] = - traversal.flatMap(_._importViaIsCallForImportOut) + def referencedImports: Iterator[Import] = + traversal.flatMap(_._importViaIsCallForImportOut) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/callgraphextension/MethodTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/callgraphextension/MethodTraversal.scala index 911703c8..24ba1b67 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/callgraphextension/MethodTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/callgraphextension/MethodTraversal.scala @@ -6,67 +6,67 @@ import overflowdb.traversal.help.Doc class MethodTraversal(val traversal: Iterator[Method]) extends AnyVal: - /** Intended for internal use! Traverse to direct and transitive callers of the method. - */ - def calledByIncludingSink(sourceTrav: Iterator[Method])(implicit - callResolver: ICallResolver - ): Iterator[Method] = - val sourceMethods = sourceTrav.toSet - val sinkMethods = traversal.dedup + /** Intended for internal use! Traverse to direct and transitive callers of the method. + */ + def calledByIncludingSink(sourceTrav: Iterator[Method])(implicit + callResolver: ICallResolver + ): Iterator[Method] = + val sourceMethods = sourceTrav.toSet + val sinkMethods = traversal.dedup - if sourceMethods.isEmpty || sinkMethods.isEmpty then - Iterator.empty[Method].enablePathTracking - else - sinkMethods - .repeat( - _.flatMap( - callResolver.getMethodCallsitesAsTraversal - )._containsIn // expand to method - )(_.dedup.emit(_.collect { - case method: Method if sourceMethods.contains(method) => method - })) - .cast[Method] + if sourceMethods.isEmpty || sinkMethods.isEmpty then + Iterator.empty[Method].enablePathTracking + else + sinkMethods + .repeat( + _.flatMap( + callResolver.getMethodCallsitesAsTraversal + )._containsIn // expand to method + )(_.dedup.emit(_.collect { + case method: Method if sourceMethods.contains(method) => method + })) + .cast[Method] - /** Traverse to direct callers of this method - */ - def caller(implicit callResolver: ICallResolver): Iterator[Method] = - callIn(callResolver).method + /** Traverse to direct callers of this method + */ + def caller(implicit callResolver: ICallResolver): Iterator[Method] = + callIn(callResolver).method - /** Traverse to methods called by this method - */ - def callee(implicit callResolver: ICallResolver): Iterator[Method] = - call.callee(callResolver) + /** Traverse to methods called by this method + */ + def callee(implicit callResolver: ICallResolver): Iterator[Method] = + call.callee(callResolver) - /** Incoming call sites - */ - def callIn(implicit callResolver: ICallResolver): Iterator[Call] = - traversal.flatMap(method => - callResolver.getMethodCallsitesAsTraversal(method).collectAll[Call] - ) + /** Incoming call sites + */ + def callIn(implicit callResolver: ICallResolver): Iterator[Call] = + traversal.flatMap(method => + callResolver.getMethodCallsitesAsTraversal(method).collectAll[Call] + ) - /** Traverse to direct and transitive callers of the method. - */ - def calledBy(sourceTrav: Iterator[Method])(implicit - callResolver: ICallResolver - ): Iterator[Method] = - caller(callResolver).calledByIncludingSink(sourceTrav)(callResolver) + /** Traverse to direct and transitive callers of the method. + */ + def calledBy(sourceTrav: Iterator[Method])(implicit + callResolver: ICallResolver + ): Iterator[Method] = + caller(callResolver).calledByIncludingSink(sourceTrav)(callResolver) - @deprecated("Use call", "") - def callOut: Iterator[Call] = - call + @deprecated("Use call", "") + def callOut: Iterator[Call] = + call - @deprecated("Use call", "") - def callOutRegex(regex: String)(implicit callResolver: ICallResolver): Iterator[Call] = - call(regex) + @deprecated("Use call", "") + def callOutRegex(regex: String)(implicit callResolver: ICallResolver): Iterator[Call] = + call(regex) - /** Outgoing call sites to methods where fullName matches `regex`. - */ - def call(regex: String)(implicit callResolver: ICallResolver): Iterator[Call] = - call.where(_.callee.fullName(regex)) + /** Outgoing call sites to methods where fullName matches `regex`. + */ + def call(regex: String)(implicit callResolver: ICallResolver): Iterator[Call] = + call.where(_.callee.fullName(regex)) - /** Outgoing call sites - */ - @Doc(info = "Call sites (outgoing calls)") - def call: Iterator[Call] = - traversal.flatMap(_._callViaContainsOut) + /** Outgoing call sites + */ + @Doc(info = "Call sites (outgoing calls)") + def call: Iterator[Call] = + traversal.flatMap(_._callViaContainsOut) end MethodTraversal diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/AstNodeDot.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/AstNodeDot.scala index eee4a717..72c02685 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/AstNodeDot.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/AstNodeDot.scala @@ -6,7 +6,7 @@ import overflowdb.traversal.* class AstNodeDot[NodeType <: AstNode](val traversal: Iterator[NodeType]) extends AnyVal: - def dotAst: Iterator[String] = DotAstGenerator.dotAst(traversal) + def dotAst: Iterator[String] = DotAstGenerator.dotAst(traversal) - def plotDotAst(implicit viewer: ImageViewer): Unit = - Shared.plotAndDisplay(dotAst.l, viewer) + def plotDotAst(implicit viewer: ImageViewer): Unit = + Shared.plotAndDisplay(dotAst.l, viewer) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/CfgNodeDot.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/CfgNodeDot.scala index 3da7357f..03c50848 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/CfgNodeDot.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/CfgNodeDot.scala @@ -6,12 +6,12 @@ import overflowdb.traversal.* class CfgNodeDot(val traversal: Iterator[Method]) extends AnyVal: - def dotCfg: Iterator[String] = DotCfgGenerator.dotCfg(traversal) + def dotCfg: Iterator[String] = DotCfgGenerator.dotCfg(traversal) - def dotCdg: Iterator[String] = DotCdgGenerator.dotCdg(traversal) + def dotCdg: Iterator[String] = DotCdgGenerator.dotCdg(traversal) - def plotDotCfg(implicit viewer: ImageViewer): Unit = - Shared.plotAndDisplay(dotCfg.l, viewer) + def plotDotCfg(implicit viewer: ImageViewer): Unit = + Shared.plotAndDisplay(dotCfg.l, viewer) - def plotDotCdg(implicit viewer: ImageViewer): Unit = - Shared.plotAndDisplay(dotCdg.l, viewer) + def plotDotCdg(implicit viewer: ImageViewer): Unit = + Shared.plotAndDisplay(dotCdg.l, viewer) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/InterproceduralNodeDot.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/InterproceduralNodeDot.scala index fff4c131..5edbcfd8 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/InterproceduralNodeDot.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/InterproceduralNodeDot.scala @@ -5,6 +5,6 @@ import io.shiftleft.semanticcpg.dotgenerator.{DotCallGraphGenerator, DotTypeHier class InterproceduralNodeDot(val cpg: Cpg) extends AnyVal: - def dotCallGraph: Iterator[String] = DotCallGraphGenerator.dotCallGraph(cpg) + def dotCallGraph: Iterator[String] = DotCallGraphGenerator.dotCallGraph(cpg) - def dotTypeHierarchy: Iterator[String] = DotTypeHierarchyGenerator.dotTypeHierarchy(cpg) + def dotTypeHierarchy: Iterator[String] = DotTypeHierarchyGenerator.dotTypeHierarchy(cpg) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/Shared.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/Shared.scala index a58e70c5..90f66187 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/Shared.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/dotextension/Shared.scala @@ -6,35 +6,35 @@ import scala.sys.process.Process import scala.util.{Failure, Success, Try} trait ImageViewer: - def view(pathStr: String): Try[String] + def view(pathStr: String): Try[String] object Shared: - def plotAndDisplay(dotStrings: List[String], viewer: ImageViewer): Unit = - dotStrings.foreach { dotString => - File.usingTemporaryFile("semanticcpg") { dotFile => - File.usingTemporaryFile("semanticcpg") { svgFile => - dotFile.write(dotString) - createSvgFile(dotFile, svgFile).toOption.foreach(_ => - viewer.view(svgFile.path.toAbsolutePath.toString) - ) - } - } - } + def plotAndDisplay(dotStrings: List[String], viewer: ImageViewer): Unit = + dotStrings.foreach { dotString => + File.usingTemporaryFile("semanticcpg") { dotFile => + File.usingTemporaryFile("semanticcpg") { svgFile => + dotFile.write(dotString) + createSvgFile(dotFile, svgFile).toOption.foreach(_ => + viewer.view(svgFile.path.toAbsolutePath.toString) + ) + } + } + } - private def createSvgFile(in: File, out: File): Try[String] = - Try { - Process(Seq( - "dot", - "-Tsvg", - in.path.toAbsolutePath.toString, - "-o", - out.path.toAbsolutePath.toString - )).!! - } match - case Success(v) => Success(v) - case Failure(exc) => - System.err.println("Executing `dot` failed: is `graphviz` installed?") - System.err.println(exc) - Failure(exc) + private def createSvgFile(in: File, out: File): Try[String] = + Try { + Process(Seq( + "dot", + "-Tsvg", + in.path.toAbsolutePath.toString, + "-o", + out.path.toAbsolutePath.toString + )).!! + } match + case Success(v) => Success(v) + case Failure(exc) => + System.err.println("Executing `dot` failed: is `graphviz` installed?") + System.err.println(exc) + Failure(exc) end Shared diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/AstNodeMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/AstNodeMethods.scala index 1c5a17ce..7240ca83 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/AstNodeMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/AstNodeMethods.scala @@ -9,133 +9,133 @@ import io.shiftleft.semanticcpg.utils.MemberAccess class AstNodeMethods(val node: AstNode) extends AnyVal with NodeExtension: - /** Indicate whether the AST node represents a control structure, e.g., `if`, `for`, `while`. - */ - def isControlStructure: Boolean = node.isInstanceOf[ControlStructure] + /** Indicate whether the AST node represents a control structure, e.g., `if`, `for`, `while`. + */ + def isControlStructure: Boolean = node.isInstanceOf[ControlStructure] - def isIdentifier: Boolean = node.isInstanceOf[Identifier] + def isIdentifier: Boolean = node.isInstanceOf[Identifier] - def isImport: Boolean = node.isInstanceOf[Import] + def isImport: Boolean = node.isInstanceOf[Import] - def isFieldIdentifier: Boolean = node.isInstanceOf[FieldIdentifier] + def isFieldIdentifier: Boolean = node.isInstanceOf[FieldIdentifier] - def isFile: Boolean = node.isInstanceOf[File] + def isFile: Boolean = node.isInstanceOf[File] - def isReturn: Boolean = node.isInstanceOf[Return] + def isReturn: Boolean = node.isInstanceOf[Return] - def isLiteral: Boolean = node.isInstanceOf[Literal] + def isLiteral: Boolean = node.isInstanceOf[Literal] - def isLocal: Boolean = node.isInstanceOf[Local] + def isLocal: Boolean = node.isInstanceOf[Local] - def isCall: Boolean = node.isInstanceOf[Call] + def isCall: Boolean = node.isInstanceOf[Call] - def isExpression: Boolean = node.isInstanceOf[Expression] + def isExpression: Boolean = node.isInstanceOf[Expression] - def isMember: Boolean = node.isInstanceOf[Member] + def isMember: Boolean = node.isInstanceOf[Member] - def isMethodRef: Boolean = node.isInstanceOf[MethodRef] + def isMethodRef: Boolean = node.isInstanceOf[MethodRef] - def isMethod: Boolean = node.isInstanceOf[Method] + def isMethod: Boolean = node.isInstanceOf[Method] - def isModifier: Boolean = node.isInstanceOf[Modifier] + def isModifier: Boolean = node.isInstanceOf[Modifier] - def isNamespaceBlock: Boolean = node.isInstanceOf[NamespaceBlock] + def isNamespaceBlock: Boolean = node.isInstanceOf[NamespaceBlock] - def isBlock: Boolean = node.isInstanceOf[Block] + def isBlock: Boolean = node.isInstanceOf[Block] - def isParameter: Boolean = node.isInstanceOf[MethodParameterIn] + def isParameter: Boolean = node.isInstanceOf[MethodParameterIn] - def isTypeDecl: Boolean = node.isInstanceOf[TypeDecl] + def isTypeDecl: Boolean = node.isInstanceOf[TypeDecl] - def depth: Int = depth(_ => true) + def depth: Int = depth(_ => true) - /** The depth of the AST rooted in this node. Upon walking the tree to its leaves, the depth is - * only increased for nodes where `p(node)` is true. - */ - def depth(p: AstNode => Boolean): Int = - val additionalDepth = if p(node) then 1 - else 0 + /** The depth of the AST rooted in this node. Upon walking the tree to its leaves, the depth is + * only increased for nodes where `p(node)` is true. + */ + def depth(p: AstNode => Boolean): Int = + val additionalDepth = if p(node) then 1 + else 0 - val childDepths = node.astChildren.map(_.depth(p)).l - additionalDepth + (if childDepths.isEmpty then - 0 - else - childDepths.max - ) + val childDepths = node.astChildren.map(_.depth(p)).l + additionalDepth + (if childDepths.isEmpty then + 0 + else + childDepths.max + ) - def astParent: AstNode = - try - node._astIn.onlyChecked.asInstanceOf[AstNode] - catch - case _: Throwable => node._astIn.asInstanceOf[AstNode] + def astParent: AstNode = + try + node._astIn.onlyChecked.asInstanceOf[AstNode] + catch + case _: Throwable => node._astIn.asInstanceOf[AstNode] - /** Direct children of node in the AST. Siblings are ordered by their `order` fields - */ - def astChildren: Iterator[AstNode] = - node._astOut.cast[AstNode].sortBy(_.order).iterator + /** Direct children of node in the AST. Siblings are ordered by their `order` fields + */ + def astChildren: Iterator[AstNode] = + node._astOut.cast[AstNode].sortBy(_.order).iterator - /** Siblings of this node in the AST, ordered by their `order` fields - */ - def astSiblings: Iterator[AstNode] = - astParent.astChildren.filter(_ != node) + /** Siblings of this node in the AST, ordered by their `order` fields + */ + def astSiblings: Iterator[AstNode] = + astParent.astChildren.filter(_ != node) - /** Nodes of the AST rooted in this node, including the node itself. - */ - def ast: Iterator[AstNode] = - Iterator.single(node).ast + /** Nodes of the AST rooted in this node, including the node itself. + */ + def ast: Iterator[AstNode] = + Iterator.single(node).ast - /** Textual representation of AST node - */ - def repr: String = - node match - case method: Method => method.name - case member: Member => member.name - case methodReturn: MethodReturn => methodReturn.code - case expr: Expression => expr.code - case call: CallRepr if !call.isInstanceOf[Call] => call.code + /** Textual representation of AST node + */ + def repr: String = + node match + case method: Method => method.name + case member: Member => member.name + case methodReturn: MethodReturn => methodReturn.code + case expr: Expression => expr.code + case call: CallRepr if !call.isInstanceOf[Call] => call.code - def statement: AstNode = - statementInternal(node, _.parentExpression.get) + def statement: AstNode = + statementInternal(node, _.parentExpression.get) - @scala.annotation.tailrec - private def statementInternal( - node: AstNode, - parentExpansion: Expression => Expression - ): AstNode = - node match - case node: Identifier => parentExpansion(node) - case node: MethodRef => parentExpansion(node) - case node: TypeRef => parentExpansion(node) - case node: Literal => parentExpansion(node) + @scala.annotation.tailrec + private def statementInternal( + node: AstNode, + parentExpansion: Expression => Expression + ): AstNode = + node match + case node: Identifier => parentExpansion(node) + case node: MethodRef => parentExpansion(node) + case node: TypeRef => parentExpansion(node) + case node: Literal => parentExpansion(node) - case member: Member => member - case node: MethodParameterIn => node.method + case member: Member => member + case node: MethodParameterIn => node.method - case node: MethodParameterOut => - node.method.methodReturn + case node: MethodParameterOut => + node.method.methodReturn - case node: Call if MemberAccess.isGenericMemberAccessName(node.name) => - parentExpansion(node) + case node: Call if MemberAccess.isGenericMemberAccessName(node.name) => + parentExpansion(node) - case node: CallRepr => node - case node: MethodReturn => node - case block: Block => - // Just taking the lastExpressionInBlock is not quite correct because a BLOCK could have - // different return expressions. So we would need to expand via CFG. - // But currently the frontends do not even put the BLOCK into the CFG so this is the best - // we can do. - statementInternal(lastExpressionInBlock(block).get, identity) - case node: Expression => node + case node: CallRepr => node + case node: MethodReturn => node + case block: Block => + // Just taking the lastExpressionInBlock is not quite correct because a BLOCK could have + // different return expressions. So we would need to expand via CFG. + // But currently the frontends do not even put the BLOCK into the CFG so this is the best + // we can do. + statementInternal(lastExpressionInBlock(block).get, identity) + case node: Expression => node end AstNodeMethods object AstNodeMethods: - private def lastExpressionInBlock(block: Block): Option[Expression] = - block._astOut - .collect { - case node: Expression if !node.isInstanceOf[Local] && !node.isInstanceOf[Method] => - node - } - .toVector - .sortBy(_.order) - .lastOption + private def lastExpressionInBlock(block: Block): Option[Expression] = + block._astOut + .collect { + case node: Expression if !node.isInstanceOf[Local] && !node.isInstanceOf[Method] => + node + } + .toVector + .sortBy(_.order) + .lastOption diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/CallMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/CallMethods.scala index edba3f54..89da9b87 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/CallMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/CallMethods.scala @@ -5,26 +5,26 @@ import io.shiftleft.semanticcpg.NodeExtension import io.shiftleft.semanticcpg.language.* class CallMethods(val node: Call) extends AnyVal with NodeExtension with HasLocation: - def receiver: Iterator[Expression] = - node.receiverOut + def receiver: Iterator[Expression] = + node.receiverOut - def arguments(index: Int): Iterator[Expression] = - node._argumentOut - .collect { - case expr: Expression if expr.argumentIndex == index => expr - } + def arguments(index: Int): Iterator[Expression] = + node._argumentOut + .collect { + case expr: Expression if expr.argumentIndex == index => expr + } - def argument: Iterator[Expression] = - node._argumentOut.collectAll[Expression] + def argument: Iterator[Expression] = + node._argumentOut.collectAll[Expression] - def argument(index: Int): Expression = - arguments(index).head + def argument(index: Int): Expression = + arguments(index).head - def argumentOption(index: Int): Option[Expression] = - node._argumentOut.collectFirst { - case expr: Expression if expr.argumentIndex == index => expr - } + def argumentOption(index: Int): Option[Expression] = + node._argumentOut.collectFirst { + case expr: Expression if expr.argumentIndex == index => expr + } - override def location: NewLocation = - LocationCreator(node, node.code, node.label, node.lineNumber, node.method) + override def location: NewLocation = + LocationCreator(node, node.code, node.label, node.lineNumber, node.method) end CallMethods diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/CfgNodeMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/CfgNodeMethods.scala index 185cba76..a18a7a1c 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/CfgNodeMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/CfgNodeMethods.scala @@ -9,119 +9,118 @@ import scala.jdk.CollectionConverters.* class CfgNodeMethods(val node: CfgNode) extends AnyVal with NodeExtension: - /** Successors in the CFG - */ - def cfgNext: Iterator[CfgNode] = - Iterator.single(node).cfgNext - - /** Maps each node in the traversal to a traversal returning its n successors. - */ - def cfgNext(n: Int): Iterator[CfgNode] = n match - case 0 => Iterator.empty - case _ => cfgNext.flatMap(x => List(x) ++ x.cfgNext(n - 1)) - - /** Maps each node in the traversal to a traversal returning its n predecessors. - */ - def cfgPrev(n: Int): Iterator[CfgNode] = n match - case 0 => Iterator.empty - case _ => cfgPrev.flatMap(x => List(x) ++ x.cfgPrev(n - 1)) - - /** Predecessors in the CFG - */ - def cfgPrev: Iterator[CfgNode] = - Iterator.single(node).cfgPrev - - /** Recursively determine all nodes on which this CFG node is control-dependent. - */ - def controlledBy: Iterator[CfgNode] = - expandExhaustively { v => - v._cdgIn - } - - /** Recursively determine all nodes which this CFG node controls - */ - def controls: Iterator[CfgNode] = - expandExhaustively { v => - v._cdgOut - } - - /** Recursively determine all nodes by which this node is dominated - */ - def dominatedBy: Iterator[CfgNode] = - expandExhaustively { v => - v._dominateIn - } - - /** Recursively determine all nodes which are dominated by this node - */ - def dominates: Iterator[CfgNode] = - expandExhaustively { v => - v._dominateOut - } - - /** Recursively determine all nodes by which this node is post dominated - */ - def postDominatedBy: Iterator[CfgNode] = - expandExhaustively { v => - v._postDominateIn - } - - /** Recursively determine all nodes which are post dominated by this node - */ - def postDominates: Iterator[CfgNode] = - expandExhaustively { v => - v._postDominateOut - } - - private def expandExhaustively(expand: CfgNode => Iterator[StoredNode]): Iterator[CfgNode] = - var controllingNodes = List.empty[CfgNode] - var visited = Set.empty + node - var worklist = node :: Nil - - while worklist.nonEmpty do - val vertex = worklist.head - worklist = worklist.tail - - expand(vertex).foreach { case controllingNode: CfgNode => - if !visited.contains(controllingNode) then - visited += controllingNode - controllingNodes = controllingNode :: controllingNodes - worklist = controllingNode :: worklist - } - controllingNodes.iterator - - def method: Method = node match - case node: Method => node - case _: MethodParameterIn | _: MethodParameterOut | _: MethodReturn => - walkUpAst(node) - case _: CallRepr if !node.isInstanceOf[Call] => walkUpAst(node) - case _: Annotation | _: AnnotationLiteral => node.inAst.collectAll[Method].head - case _: Expression | _: JumpTarget => walkUpContains(node) - - /** Obtain hexadecimal string representation of lineNumber field. - * - * Binary frontends store addresses in the lineNumber field as integers. For interoperability - * with other binary analysis tooling, it is convenient to allow retrieving these as hex - * strings. - */ - def address: Option[String] = - node.lineNumber.map(_.toLong.toHexString) - - private def walkUpAst(node: CfgNode): Method = - node._astIn.onlyChecked.asInstanceOf[Method] - - private def walkUpContains(node: StoredNode): Method = - node._containsIn.onlyChecked match - case method: Method => method - case typeDecl: TypeDecl => - typeDecl.astParent match - case namespaceBlock: NamespaceBlock => - // For Typescript, types may be declared in namespaces which we represent as NamespaceBlocks - namespaceBlock.inAst.collectAll[Method].headOption.orNull - case method: Method => - // For a language such as Javascript, types may be dynamically declared under procedures - method - case _ => - // there are csharp CPGs that have typedecls here, which is invalid. - null + /** Successors in the CFG + */ + def cfgNext: Iterator[CfgNode] = + Iterator.single(node).cfgNext + + /** Maps each node in the traversal to a traversal returning its n successors. + */ + def cfgNext(n: Int): Iterator[CfgNode] = n match + case 0 => Iterator.empty + case _ => cfgNext.flatMap(x => List(x) ++ x.cfgNext(n - 1)) + + /** Maps each node in the traversal to a traversal returning its n predecessors. + */ + def cfgPrev(n: Int): Iterator[CfgNode] = n match + case 0 => Iterator.empty + case _ => cfgPrev.flatMap(x => List(x) ++ x.cfgPrev(n - 1)) + + /** Predecessors in the CFG + */ + def cfgPrev: Iterator[CfgNode] = + Iterator.single(node).cfgPrev + + /** Recursively determine all nodes on which this CFG node is control-dependent. + */ + def controlledBy: Iterator[CfgNode] = + expandExhaustively { v => + v._cdgIn + } + + /** Recursively determine all nodes which this CFG node controls + */ + def controls: Iterator[CfgNode] = + expandExhaustively { v => + v._cdgOut + } + + /** Recursively determine all nodes by which this node is dominated + */ + def dominatedBy: Iterator[CfgNode] = + expandExhaustively { v => + v._dominateIn + } + + /** Recursively determine all nodes which are dominated by this node + */ + def dominates: Iterator[CfgNode] = + expandExhaustively { v => + v._dominateOut + } + + /** Recursively determine all nodes by which this node is post dominated + */ + def postDominatedBy: Iterator[CfgNode] = + expandExhaustively { v => + v._postDominateIn + } + + /** Recursively determine all nodes which are post dominated by this node + */ + def postDominates: Iterator[CfgNode] = + expandExhaustively { v => + v._postDominateOut + } + + private def expandExhaustively(expand: CfgNode => Iterator[StoredNode]): Iterator[CfgNode] = + var controllingNodes = List.empty[CfgNode] + var visited = Set.empty + node + var worklist = node :: Nil + + while worklist.nonEmpty do + val vertex = worklist.head + worklist = worklist.tail + + expand(vertex).foreach { case controllingNode: CfgNode => + if !visited.contains(controllingNode) then + visited += controllingNode + controllingNodes = controllingNode :: controllingNodes + worklist = controllingNode :: worklist + } + controllingNodes.iterator + + def method: Method = node match + case node: Method => node + case _: MethodParameterIn | _: MethodParameterOut | _: MethodReturn => + walkUpAst(node) + case _: CallRepr if !node.isInstanceOf[Call] => walkUpAst(node) + case _: Annotation | _: AnnotationLiteral => node.inAst.collectAll[Method].head + case _: Expression | _: JumpTarget => walkUpContains(node) + + /** Obtain hexadecimal string representation of lineNumber field. + * + * Binary frontends store addresses in the lineNumber field as integers. For interoperability + * with other binary analysis tooling, it is convenient to allow retrieving these as hex strings. + */ + def address: Option[String] = + node.lineNumber.map(_.toLong.toHexString) + + private def walkUpAst(node: CfgNode): Method = + node._astIn.onlyChecked.asInstanceOf[Method] + + private def walkUpContains(node: StoredNode): Method = + node._containsIn.onlyChecked match + case method: Method => method + case typeDecl: TypeDecl => + typeDecl.astParent match + case namespaceBlock: NamespaceBlock => + // For Typescript, types may be declared in namespaces which we represent as NamespaceBlocks + namespaceBlock.inAst.collectAll[Method].headOption.orNull + case method: Method => + // For a language such as Javascript, types may be dynamically declared under procedures + method + case _ => + // there are csharp CPGs that have typedecls here, which is invalid. + null end CfgNodeMethods diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/ExpressionMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/ExpressionMethods.scala index 1df2e6e3..ae908fc0 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/ExpressionMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/ExpressionMethods.scala @@ -15,54 +15,54 @@ import scala.jdk.CollectionConverters.* // got exposed and for now we do not want to break the API. class ExpressionMethods(val node: Expression) extends AnyVal with NodeExtension: - /** Traverse to it's parent expression (e.g. call or return) by following the incoming AST It's - * continuing it's walk until it hits an expression that's not a generic "member access - * operation", e.g., ".memberAccess". - */ - def parentExpression: Option[Expression] = _parentExpression(node) + /** Traverse to it's parent expression (e.g. call or return) by following the incoming AST It's + * continuing it's walk until it hits an expression that's not a generic "member access + * operation", e.g., ".memberAccess". + */ + def parentExpression: Option[Expression] = _parentExpression(node) - @tailrec - private final def _parentExpression(argument: AstNode): Option[Expression] = - val parent = argument._astIn.onlyChecked - parent match - case call: Call if MemberAccess.isGenericMemberAccessName(call.name) => - _parentExpression(call) - case expression: Expression => - Some(expression) - case annotationParameterAssign: AnnotationParameterAssign => - _parentExpression(annotationParameterAssign) - case _ => - None + @tailrec + private final def _parentExpression(argument: AstNode): Option[Expression] = + val parent = argument._astIn.onlyChecked + parent match + case call: Call if MemberAccess.isGenericMemberAccessName(call.name) => + _parentExpression(call) + case expression: Expression => + Some(expression) + case annotationParameterAssign: AnnotationParameterAssign => + _parentExpression(annotationParameterAssign) + case _ => + None - def expressionUp: Iterator[Expression] = - node._astIn.collectAll[Expression] + def expressionUp: Iterator[Expression] = + node._astIn.collectAll[Expression] - def expressionDown: Iterator[Expression] = - node._astOut.collectAll[Expression] + def expressionDown: Iterator[Expression] = + node._astOut.collectAll[Expression] - def receivedCall: Iterator[Call] = - node._receiverIn.cast[Call] + def receivedCall: Iterator[Call] = + node._receiverIn.cast[Call] - def isArgument: Iterator[Expression] = - if node._argumentIn.hasNext then Iterator.single(node) - else Iterator.empty + def isArgument: Iterator[Expression] = + if node._argumentIn.hasNext then Iterator.single(node) + else Iterator.empty - def inCall: Iterator[Call] = - node._argumentIn.headOption match - case Some(c: Call) => Iterator.single(c) - case _ => Iterator.empty + def inCall: Iterator[Call] = + node._argumentIn.headOption match + case Some(c: Call) => Iterator.single(c) + case _ => Iterator.empty - def parameter(implicit callResolver: ICallResolver): Iterator[MethodParameterIn] = - // Expressions can have incoming argument edges not just from CallRepr nodes but also - // from Return nodes for which an expansion to parameter makes no sense. So we filter - // for CallRepr. - for - call <- node._argumentIn if call.isInstanceOf[CallRepr] - calledMethods <- callResolver.getCalledMethods(call.asInstanceOf[CallRepr]) - paramIn <- calledMethods._astOut.collectAll[MethodParameterIn] - if paramIn.index == node.argumentIndex - yield paramIn + def parameter(implicit callResolver: ICallResolver): Iterator[MethodParameterIn] = + // Expressions can have incoming argument edges not just from CallRepr nodes but also + // from Return nodes for which an expansion to parameter makes no sense. So we filter + // for CallRepr. + for + call <- node._argumentIn if call.isInstanceOf[CallRepr] + calledMethods <- callResolver.getCalledMethods(call.asInstanceOf[CallRepr]) + paramIn <- calledMethods._astOut.collectAll[MethodParameterIn] + if paramIn.index == node.argumentIndex + yield paramIn - def typ: Iterator[Type] = - node._evalTypeOut.cast[Type] + def typ: Iterator[Type] = + node._evalTypeOut.cast[Type] end ExpressionMethods diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/IdentifierMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/IdentifierMethods.scala index 3c9e6646..f212bc18 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/IdentifierMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/IdentifierMethods.scala @@ -6,11 +6,11 @@ import io.shiftleft.semanticcpg.language.{HasLocation, LocationCreator, *} class IdentifierMethods(val identifier: Identifier) extends AnyVal with NodeExtension with HasLocation: - override def location: NewLocation = - LocationCreator( - identifier, - identifier.name, - identifier.label, - identifier.lineNumber, - identifier.method - ) + override def location: NewLocation = + LocationCreator( + identifier, + identifier.name, + identifier.label, + identifier.lineNumber, + identifier.method + ) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/LiteralMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/LiteralMethods.scala index c3e78471..14044b3b 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/LiteralMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/LiteralMethods.scala @@ -5,5 +5,5 @@ import io.shiftleft.semanticcpg.NodeExtension import io.shiftleft.semanticcpg.language.{HasLocation, LocationCreator, *} class LiteralMethods(val literal: Literal) extends AnyVal with NodeExtension with HasLocation: - override def location: NewLocation = - LocationCreator(literal, literal.code, literal.label, literal.lineNumber, literal.method) + override def location: NewLocation = + LocationCreator(literal, literal.code, literal.label, literal.lineNumber, literal.method) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/LocalMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/LocalMethods.scala index d65271d5..95951b73 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/LocalMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/LocalMethods.scala @@ -5,10 +5,10 @@ import io.shiftleft.semanticcpg.NodeExtension import io.shiftleft.semanticcpg.language.* class LocalMethods(val local: Local) extends AnyVal with NodeExtension with HasLocation: - override def location: NewLocation = - LocationCreator(local, local.name, local.label, local.lineNumber, local.method.head) + override def location: NewLocation = + LocationCreator(local, local.name, local.label, local.lineNumber, local.method.head) - /** The method hosting this local variable - */ - def method: Iterator[Method] = - Iterator.single(local).method + /** The method hosting this local variable + */ + def method: Iterator[Method] = + Iterator.single(local).method diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodMethods.scala index 7d832e4c..81f29fd0 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodMethods.scala @@ -6,59 +6,59 @@ import io.shiftleft.semanticcpg.language.* class MethodMethods(val method: Method) extends AnyVal with NodeExtension with HasLocation: - /** Traverse to annotations of method - */ - def annotation: Iterator[Annotation] = - method._annotationViaAstOut + /** Traverse to annotations of method + */ + def annotation: Iterator[Annotation] = + method._annotationViaAstOut - def local: Iterator[Local] = - method._blockViaContainsOut.local + def local: Iterator[Local] = + method._blockViaContainsOut.local - /** All control structures of this method - */ - def controlStructure: Iterator[ControlStructure] = - method.ast.isControlStructure + /** All control structures of this method + */ + def controlStructure: Iterator[ControlStructure] = + method.ast.isControlStructure - def numberOfLines: Int = - if method.lineNumber.isDefined && method.lineNumberEnd.isDefined then - method.lineNumberEnd.get - method.lineNumber.get + 1 - else - 0 + def numberOfLines: Int = + if method.lineNumber.isDefined && method.lineNumberEnd.isDefined then + method.lineNumberEnd.get - method.lineNumber.get + 1 + else + 0 - def isVariadic: Boolean = - method.parameter.exists(_.isVariadic) + def isVariadic: Boolean = + method.parameter.exists(_.isVariadic) - def cfgNode: Iterator[CfgNode] = - method._containsOut.collectAll[CfgNode] + def cfgNode: Iterator[CfgNode] = + method._containsOut.collectAll[CfgNode] - /** List of CFG nodes in reverse post order - */ - def reversePostOrder: Iterator[CfgNode] = - def expand(x: CfgNode) = x.cfgNext.iterator - NodeOrdering.reverseNodeList( - NodeOrdering.postOrderNumbering(method, expand).toList - ).iterator + /** List of CFG nodes in reverse post order + */ + def reversePostOrder: Iterator[CfgNode] = + def expand(x: CfgNode) = x.cfgNext.iterator + NodeOrdering.reverseNodeList( + NodeOrdering.postOrderNumbering(method, expand).toList + ).iterator - /** List of CFG nodes in post order - */ - def postOrder: Iterator[CfgNode] = - def expand(x: CfgNode) = x.cfgNext.iterator - NodeOrdering.nodeList(NodeOrdering.postOrderNumbering(method, expand).toList).iterator + /** List of CFG nodes in post order + */ + def postOrder: Iterator[CfgNode] = + def expand(x: CfgNode) = x.cfgNext.iterator + NodeOrdering.nodeList(NodeOrdering.postOrderNumbering(method, expand).toList).iterator - /** The type declaration associated with this method, e.g., the class it is defined in. - */ - def definingTypeDecl: Option[TypeDecl] = - Iterator.single(method).definingTypeDecl.headOption + /** The type declaration associated with this method, e.g., the class it is defined in. + */ + def definingTypeDecl: Option[TypeDecl] = + Iterator.single(method).definingTypeDecl.headOption - /** The type declaration associated with this method, e.g., the class it is defined in. Alias - * for 'definingTypeDecl' - */ - def typeDecl: Option[TypeDecl] = definingTypeDecl + /** The type declaration associated with this method, e.g., the class it is defined in. Alias for + * 'definingTypeDecl' + */ + def typeDecl: Option[TypeDecl] = definingTypeDecl - /** Traverse to method body (alias for `block`) */ - def body: Block = - method.block + /** Traverse to method body (alias for `block`) */ + def body: Block = + method.block - override def location: NewLocation = - LocationCreator(method, method.name, method.label, method.lineNumber, method) + override def location: NewLocation = + LocationCreator(method, method.name, method.label, method.lineNumber, method) end MethodMethods diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodParameterInMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodParameterInMethods.scala index 9e87f382..48e7c578 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodParameterInMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodParameterInMethods.scala @@ -6,5 +6,5 @@ import io.shiftleft.semanticcpg.language.{HasLocation, LocationCreator} class MethodParameterInMethods(val paramIn: MethodParameterIn) extends AnyVal with NodeExtension with HasLocation: - override def location: NewLocation = - LocationCreator(paramIn, paramIn.name, paramIn.label, paramIn.lineNumber, paramIn.method) + override def location: NewLocation = + LocationCreator(paramIn, paramIn.name, paramIn.label, paramIn.lineNumber, paramIn.method) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodParameterOutMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodParameterOutMethods.scala index 34623ff6..c65eb4c9 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodParameterOutMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodParameterOutMethods.scala @@ -6,11 +6,11 @@ import io.shiftleft.semanticcpg.language.{HasLocation, LocationCreator} class MethodParameterOutMethods(val paramOut: MethodParameterOut) extends AnyVal with NodeExtension with HasLocation: - override def location: NewLocation = - LocationCreator( - paramOut, - paramOut.name, - paramOut.label, - paramOut.lineNumber, - paramOut.method - ) + override def location: NewLocation = + LocationCreator( + paramOut, + paramOut.name, + paramOut.label, + paramOut.lineNumber, + paramOut.method + ) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodRefMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodRefMethods.scala index a9ad7d9e..e5fc3b74 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodRefMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodRefMethods.scala @@ -5,11 +5,11 @@ import io.shiftleft.semanticcpg.NodeExtension import io.shiftleft.semanticcpg.language.{HasLocation, LocationCreator} class MethodRefMethods(val methodRef: MethodRef) extends AnyVal with NodeExtension with HasLocation: - override def location: NewLocation = - LocationCreator( - methodRef, - methodRef.code, - methodRef.label, - methodRef.lineNumber, - methodRef._methodViaContainsIn.next() - ) + override def location: NewLocation = + LocationCreator( + methodRef, + methodRef.code, + methodRef.label, + methodRef.lineNumber, + methodRef._methodViaContainsIn.next() + ) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodReturnMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodReturnMethods.scala index 206b926a..52143dba 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodReturnMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/MethodReturnMethods.scala @@ -6,16 +6,16 @@ import io.shiftleft.semanticcpg.language.* class MethodReturnMethods(val node: MethodReturn) extends AnyVal with NodeExtension with HasLocation: - override def location: NewLocation = - LocationCreator(node, "$ret", node.label, node.lineNumber, node.method) + override def location: NewLocation = + LocationCreator(node, "$ret", node.label, node.lineNumber, node.method) - def returnUser(implicit callResolver: ICallResolver): Iterator[Call] = - val method = node._methodViaAstIn - val callsites = callResolver.getMethodCallsites(method) - // TODO for now we filter away all implicit calls because a change of the - // return type to CallRepr would lead to a break in the API aka - // the DSL steps which are subsequently allowed to be called. Before - // we addressed this we can only return Call instances. - callsites.collectAll[Call] + def returnUser(implicit callResolver: ICallResolver): Iterator[Call] = + val method = node._methodViaAstIn + val callsites = callResolver.getMethodCallsites(method) + // TODO for now we filter away all implicit calls because a change of the + // return type to CallRepr would lead to a break in the API aka + // the DSL steps which are subsequently allowed to be called. Before + // we addressed this we can only return Call instances. + callsites.collectAll[Call] - def typ: Iterator[Type] = node.evalTypeOut + def typ: Iterator[Type] = node.evalTypeOut diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/NodeMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/NodeMethods.scala index dfa54eba..6ad89933 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/NodeMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/NodeMethods.scala @@ -7,7 +7,7 @@ import overflowdb.NodeOrDetachedNode class NodeMethods(val node: NodeOrDetachedNode) extends AnyVal with NodeExtension: - def location(implicit finder: NodeExtensionFinder): NewLocation = - node match - case storedNode: StoredNode => LocationCreator(storedNode) - case _ => LocationCreator.emptyLocation("", None) + def location(implicit finder: NodeExtensionFinder): NewLocation = + node match + case storedNode: StoredNode => LocationCreator(storedNode) + case _ => LocationCreator.emptyLocation("", None) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/StoredNodeMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/StoredNodeMethods.scala index 658df335..219289d8 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/StoredNodeMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/nodemethods/StoredNodeMethods.scala @@ -7,10 +7,10 @@ import io.shiftleft.semanticcpg.language.* import scala.jdk.CollectionConverters.* class StoredNodeMethods(val node: StoredNode) extends AnyVal with NodeExtension: - def tag: Iterator[Tag] = - node._taggedByOut - .cast[Tag] - .distinctBy(tag => (tag.name, tag.value)) + def tag: Iterator[Tag] = + node._taggedByOut + .cast[Tag] + .distinctBy(tag => (tag.name, tag.value)) - def file: Iterator[File] = - Iterator.single(node).file + def file: Iterator[File] = + Iterator.single(node).file diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/ArrayAccessTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/ArrayAccessTraversal.scala index 9e153997..dcbea5ff 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/ArrayAccessTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/ArrayAccessTraversal.scala @@ -6,26 +6,26 @@ import overflowdb.traversal.help.Doc class ArrayAccessTraversal(val traversal: Iterator[OpNodes.ArrayAccess]) extends AnyVal: - @Doc(info = "The expression representing the array") - def array: Iterator[Expression] = traversal.map(_.array) + @Doc(info = "The expression representing the array") + def array: Iterator[Expression] = traversal.map(_.array) - @Doc(info = "Offset at which the array is referenced (an expression)") - def offset: Iterator[Expression] = traversal.map(_.offset) + @Doc(info = "Offset at which the array is referenced (an expression)") + def offset: Iterator[Expression] = traversal.map(_.offset) - @Doc(info = "All identifiers that are part of the offset") - def subscript: Iterator[Identifier] = traversal.flatMap(_.subscript) + @Doc(info = "All identifiers that are part of the offset") + def subscript: Iterator[Identifier] = traversal.flatMap(_.subscript) - @Doc( - info = "Determine whether array access has constant offset", - longInfo = """ + @Doc( + info = "Determine whether array access has constant offset", + longInfo = """ Determine if array access is at constant numeric offset, e.g., `buf[10]` but not `buf[i + 10]`, and for simplicity, not even `buf[1+2]`, `buf[PROBABLY_A_CONSTANT]` or `buf[PROBABLY_A_CONSTANT + 1]`, or even `buf[PROBABLY_A_CONSTANT]`. """ - ) - def usesConstantOffset: Iterator[OpNodes.ArrayAccess] = traversal.filter(_.usesConstantOffset) + ) + def usesConstantOffset: Iterator[OpNodes.ArrayAccess] = traversal.filter(_.usesConstantOffset) - @Doc(info = "If `array` is a lone identifier, return its name") - def simpleName: Iterator[String] = traversal.flatMap(_.simpleName) + @Doc(info = "If `array` is a lone identifier, return its name") + def simpleName: Iterator[String] = traversal.flatMap(_.simpleName) end ArrayAccessTraversal diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/AssignmentTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/AssignmentTraversal.scala index 4b0966c7..283fe58a 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/AssignmentTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/AssignmentTraversal.scala @@ -8,8 +8,8 @@ import overflowdb.traversal.help.Doc @help.Traversal(elementType = classOf[OpNodes.Assignment]) class AssignmentTraversal(val traversal: Iterator[OpNodes.Assignment]) extends AnyVal: - @Doc(info = "Left-hand sides of assignments") - def target: Iterator[Expression] = traversal.map(_.target) + @Doc(info = "Left-hand sides of assignments") + def target: Iterator[Expression] = traversal.map(_.target) - @Doc(info = "Right-hand sides of assignments") - def source: Iterator[Expression] = traversal.map(_.source) + @Doc(info = "Right-hand sides of assignments") + def source: Iterator[Expression] = traversal.map(_.source) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/FieldAccessTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/FieldAccessTraversal.scala index 1e58d4cc..43841840 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/FieldAccessTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/FieldAccessTraversal.scala @@ -6,17 +6,17 @@ import overflowdb.traversal.help.Doc class FieldAccessTraversal(val traversal: Iterator[OpNodes.FieldAccess]) extends AnyVal: - @Doc(info = "Attempts to resolve the type declaration for this field access") - def typeDecl: Iterator[TypeDecl] = - traversal.flatMap(_.typeDecl) + @Doc(info = "Attempts to resolve the type declaration for this field access") + def typeDecl: Iterator[TypeDecl] = + traversal.flatMap(_.typeDecl) - // TODO there are cases for the C++ frontend where argument(2) is a CALL or IDENTIFIER, - // and we are not handling them at the moment + // TODO there are cases for the C++ frontend where argument(2) is a CALL or IDENTIFIER, + // and we are not handling them at the moment - @Doc(info = "The identifier of the referenced field (right-hand side)") - def fieldIdentifier: Iterator[FieldIdentifier] = - traversal.flatMap(_.fieldIdentifier) + @Doc(info = "The identifier of the referenced field (right-hand side)") + def fieldIdentifier: Iterator[FieldIdentifier] = + traversal.flatMap(_.fieldIdentifier) - @Doc(info = "Attempts to resolve the member referenced by this field access") - def member: Iterator[Member] = - traversal.flatMap(_.member) + @Doc(info = "Attempts to resolve the member referenced by this field access") + def member: Iterator[Member] = + traversal.flatMap(_.member) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/Implicits.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/Implicits.scala index 8b30c56c..eeac8440 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/Implicits.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/Implicits.scala @@ -5,30 +5,30 @@ import io.shiftleft.codepropertygraph.generated.nodes.{AstNode, Expression} import io.shiftleft.semanticcpg.language.operatorextension.nodemethods.* trait Implicits: - implicit def toNodeTypeStartersOperatorExtension(cpg: Cpg): NodeTypeStarters = - new NodeTypeStarters(cpg) + implicit def toNodeTypeStartersOperatorExtension(cpg: Cpg): NodeTypeStarters = + new NodeTypeStarters(cpg) - implicit def toArrayAccessExt(arrayAccess: OpNodes.ArrayAccess): ArrayAccessMethods = - new ArrayAccessMethods(arrayAccess) - implicit def toArrayAccessTrav(steps: Iterator[OpNodes.ArrayAccess]): ArrayAccessTraversal = - new ArrayAccessTraversal(steps) + implicit def toArrayAccessExt(arrayAccess: OpNodes.ArrayAccess): ArrayAccessMethods = + new ArrayAccessMethods(arrayAccess) + implicit def toArrayAccessTrav(steps: Iterator[OpNodes.ArrayAccess]): ArrayAccessTraversal = + new ArrayAccessTraversal(steps) - implicit def toFieldAccessExt(fieldAccess: OpNodes.FieldAccess): FieldAccessMethods = - new FieldAccessMethods(fieldAccess) - implicit def toFieldAccessTrav(steps: Iterator[OpNodes.FieldAccess]): FieldAccessTraversal = - new FieldAccessTraversal(steps) + implicit def toFieldAccessExt(fieldAccess: OpNodes.FieldAccess): FieldAccessMethods = + new FieldAccessMethods(fieldAccess) + implicit def toFieldAccessTrav(steps: Iterator[OpNodes.FieldAccess]): FieldAccessTraversal = + new FieldAccessTraversal(steps) - implicit def toAssignmentExt(assignment: OpNodes.Assignment): AssignmentMethods = - new AssignmentMethods(assignment) - implicit def toAssignmentTrav(steps: Iterator[OpNodes.Assignment]): AssignmentTraversal = - new AssignmentTraversal(steps) + implicit def toAssignmentExt(assignment: OpNodes.Assignment): AssignmentMethods = + new AssignmentMethods(assignment) + implicit def toAssignmentTrav(steps: Iterator[OpNodes.Assignment]): AssignmentTraversal = + new AssignmentTraversal(steps) - implicit def toTargetExt(call: Expression): TargetMethods = new TargetMethods(call) - implicit def toTargetTrav(steps: Iterator[Expression]): TargetTraversal = - new TargetTraversal(steps) + implicit def toTargetExt(call: Expression): TargetMethods = new TargetMethods(call) + implicit def toTargetTrav(steps: Iterator[Expression]): TargetTraversal = + new TargetTraversal(steps) - implicit def toOpAstNodeExt[A <: AstNode](node: A): OpAstNodeMethods[A] = - new OpAstNodeMethods(node) - implicit def toOpAstNodeTrav[A <: AstNode](steps: Iterator[A]): OpAstNodeTraversal[A] = - new OpAstNodeTraversal(steps) + implicit def toOpAstNodeExt[A <: AstNode](node: A): OpAstNodeMethods[A] = + new OpAstNodeMethods(node) + implicit def toOpAstNodeTrav[A <: AstNode](steps: Iterator[A]): OpAstNodeTraversal[A] = + new OpAstNodeTraversal(steps) end Implicits diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/NodeTypeStarters.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/NodeTypeStarters.scala index 74a50f60..9a4b962c 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/NodeTypeStarters.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/NodeTypeStarters.scala @@ -9,30 +9,30 @@ import overflowdb.traversal.help.{Doc, TraversalSource} @TraversalSource class NodeTypeStarters(cpg: Cpg): - @Doc(info = - "All assignments, including shorthand assignments that perform arithmetic (e.g., '+=')" - ) - def assignment: Iterator[OpNodes.Assignment] = - callsWithNameIn(allAssignmentTypes) - .map(new OpNodes.Assignment(_)) + @Doc(info = + "All assignments, including shorthand assignments that perform arithmetic (e.g., '+=')" + ) + def assignment: Iterator[OpNodes.Assignment] = + callsWithNameIn(allAssignmentTypes) + .map(new OpNodes.Assignment(_)) - @Doc(info = - "All arithmetic operations, including shorthand assignments that perform arithmetic (e.g., '+=')" - ) - def arithmetic: Iterator[OpNodes.Arithmetic] = - callsWithNameIn(allArithmeticTypes) - .map(new OpNodes.Arithmetic(_)) + @Doc(info = + "All arithmetic operations, including shorthand assignments that perform arithmetic (e.g., '+=')" + ) + def arithmetic: Iterator[OpNodes.Arithmetic] = + callsWithNameIn(allArithmeticTypes) + .map(new OpNodes.Arithmetic(_)) - @Doc(info = "All array accesses") - def arrayAccess: Iterator[OpNodes.ArrayAccess] = - callsWithNameIn(allArrayAccessTypes) - .map(new OpNodes.ArrayAccess(_)) + @Doc(info = "All array accesses") + def arrayAccess: Iterator[OpNodes.ArrayAccess] = + callsWithNameIn(allArrayAccessTypes) + .map(new OpNodes.ArrayAccess(_)) - @Doc(info = "Field accesses, both direct and indirect") - def fieldAccess: Iterator[OpNodes.FieldAccess] = - callsWithNameIn(allFieldAccessTypes) - .map(new OpNodes.FieldAccess(_)) + @Doc(info = "Field accesses, both direct and indirect") + def fieldAccess: Iterator[OpNodes.FieldAccess] = + callsWithNameIn(allFieldAccessTypes) + .map(new OpNodes.FieldAccess(_)) - private def callsWithNameIn(set: Set[String]) = - cpg.call.filter(x => set.contains(x.name)) + private def callsWithNameIn(set: Set[String]) = + cpg.call.filter(x => set.contains(x.name)) end NodeTypeStarters diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/OpAstNodeTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/OpAstNodeTraversal.scala index 3a7f04da..82ea8a16 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/OpAstNodeTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/OpAstNodeTraversal.scala @@ -6,28 +6,28 @@ import overflowdb.traversal.help.Doc class OpAstNodeTraversal[A <: AstNode](val traversal: Iterator[A]) extends AnyVal: - @Doc(info = "Any assignments that this node is a part of (traverse up)") - def assignment: Iterator[OpNodes.Assignment] = traversal.flatMap(_.assignment) + @Doc(info = "Any assignments that this node is a part of (traverse up)") + def assignment: Iterator[OpNodes.Assignment] = traversal.flatMap(_.assignment) - @Doc(info = "Arithmetic expressions nested in this tree") - def arithmetic: Iterator[OpNodes.Arithmetic] = traversal.flatMap(_.arithmetic) + @Doc(info = "Arithmetic expressions nested in this tree") + def arithmetic: Iterator[OpNodes.Arithmetic] = traversal.flatMap(_.arithmetic) - @Doc(info = "All array accesses") - def arrayAccess: Iterator[OpNodes.ArrayAccess] = traversal.flatMap(_.arrayAccess) + @Doc(info = "All array accesses") + def arrayAccess: Iterator[OpNodes.ArrayAccess] = traversal.flatMap(_.arrayAccess) - @Doc(info = "Field accesses, both direct and indirect") - def fieldAccess: Iterator[OpNodes.FieldAccess] = - traversal.flatMap(_.fieldAccess) + @Doc(info = "Field accesses, both direct and indirect") + def fieldAccess: Iterator[OpNodes.FieldAccess] = + traversal.flatMap(_.fieldAccess) - @Doc(info = "Any assignments that this node is a part of (traverse up)") - def inAssignment: Iterator[OpNodes.Assignment] = traversal.flatMap(_.inAssignment) + @Doc(info = "Any assignments that this node is a part of (traverse up)") + def inAssignment: Iterator[OpNodes.Assignment] = traversal.flatMap(_.inAssignment) - @Doc(info = "Any arithmetic expression that this node is a part of (traverse up)") - def inArithmetic: Iterator[OpNodes.Arithmetic] = traversal.flatMap(_.inArithmetic) + @Doc(info = "Any arithmetic expression that this node is a part of (traverse up)") + def inArithmetic: Iterator[OpNodes.Arithmetic] = traversal.flatMap(_.inArithmetic) - @Doc(info = "Any array access that this node is a part of (traverse up)") - def inArrayAccess: Iterator[OpNodes.ArrayAccess] = traversal.flatMap(_.inArrayAccess) + @Doc(info = "Any array access that this node is a part of (traverse up)") + def inArrayAccess: Iterator[OpNodes.ArrayAccess] = traversal.flatMap(_.inArrayAccess) - @Doc(info = "Any field access that this node is a part of (traverse up)") - def inFieldAccess: Iterator[OpNodes.FieldAccess] = traversal.flatMap(_.inFieldAccess) + @Doc(info = "Any field access that this node is a part of (traverse up)") + def inFieldAccess: Iterator[OpNodes.FieldAccess] = traversal.flatMap(_.inFieldAccess) end OpAstNodeTraversal diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/OpNodes.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/OpNodes.scala index 67ed7bdc..7cd1b8b0 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/OpNodes.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/OpNodes.scala @@ -3,7 +3,7 @@ package io.shiftleft.semanticcpg.language.operatorextension import io.shiftleft.codepropertygraph.generated.nodes.Call object OpNodes: - class Assignment(call: Call) extends Call(call.graph, call.id) - class Arithmetic(call: Call) extends Call(call.graph, call.id) - class ArrayAccess(call: Call) extends Call(call.graph, call.id) - class FieldAccess(call: Call) extends Call(call.graph, call.id) + class Assignment(call: Call) extends Call(call.graph, call.id) + class Arithmetic(call: Call) extends Call(call.graph, call.id) + class ArrayAccess(call: Call) extends Call(call.graph, call.id) + class FieldAccess(call: Call) extends Call(call.graph, call.id) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/TargetTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/TargetTraversal.scala index 01a8c5a7..925bf686 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/TargetTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/TargetTraversal.scala @@ -6,14 +6,14 @@ import overflowdb.traversal.help.Doc class TargetTraversal(val traversal: Iterator[Expression]) extends AnyVal: - @Doc( - info = "Outer-most array access", - longInfo = """ + @Doc( + info = "Outer-most array access", + longInfo = """ Array access at highest level , e.g., in a(b(c)), the entire expression is returned, but not b(c) alone. """ - ) - def arrayAccess: Iterator[OpNodes.ArrayAccess] = traversal.flatMap(_.arrayAccess) + ) + def arrayAccess: Iterator[OpNodes.ArrayAccess] = traversal.flatMap(_.arrayAccess) - @Doc(info = "Returns 'pointer' in assignments of the form *(pointer) = x") - def pointer: Iterator[Expression] = traversal.flatMap(_.pointer) + @Doc(info = "Returns 'pointer' in assignments of the form *(pointer) = x") + def pointer: Iterator[Expression] = traversal.flatMap(_.pointer) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/ArrayAccessMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/ArrayAccessMethods.scala index 1ba8115a..c2909dda 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/ArrayAccessMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/ArrayAccessMethods.scala @@ -6,22 +6,22 @@ import io.shiftleft.semanticcpg.language.operatorextension.OpNodes class ArrayAccessMethods(val arrayAccess: OpNodes.ArrayAccess) extends AnyVal: - def array: Expression = - arrayAccess.argument(1) + def array: Expression = + arrayAccess.argument(1) - def offset: Expression = arrayAccess.argument(2) + def offset: Expression = arrayAccess.argument(2) - def subscript: Iterator[Identifier] = - offset.ast.isIdentifier + def subscript: Iterator[Identifier] = + offset.ast.isIdentifier - def usesConstantOffset: Boolean = - offset.ast.isIdentifier.nonEmpty || { - val literalIndices = offset.ast.isLiteral.l - literalIndices.size == 1 - } + def usesConstantOffset: Boolean = + offset.ast.isIdentifier.nonEmpty || { + val literalIndices = offset.ast.isLiteral.l + literalIndices.size == 1 + } - def simpleName: Iterator[String] = - arrayAccess.array match - case id: Identifier => Iterator.single(id.name) - case _ => Iterator.empty + def simpleName: Iterator[String] = + arrayAccess.array match + case id: Identifier => Iterator.single(id.name) + case _ => Iterator.empty end ArrayAccessMethods diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/AssignmentMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/AssignmentMethods.scala index 0fefd0fd..c88258ba 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/AssignmentMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/AssignmentMethods.scala @@ -6,12 +6,12 @@ import io.shiftleft.semanticcpg.language.operatorextension.OpNodes class AssignmentMethods(val assignment: OpNodes.Assignment) extends AnyVal: - def target: Expression = assignment.argument(1) + def target: Expression = assignment.argument(1) - def source: Expression = - assignment.argument.size match - case 1 => assignment.argument(1) - case 2 => assignment.argument(2) - case numberOfArguments => throw new RuntimeException( - s"Assignment statement with $numberOfArguments arguments" - ) + def source: Expression = + assignment.argument.size match + case 1 => assignment.argument(1) + case 2 => assignment.argument(2) + case numberOfArguments => throw new RuntimeException( + s"Assignment statement with $numberOfArguments arguments" + ) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/FieldAccessMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/FieldAccessMethods.scala index 50c62e70..732418d6 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/FieldAccessMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/FieldAccessMethods.scala @@ -6,16 +6,16 @@ import io.shiftleft.semanticcpg.language.operatorextension.OpNodes class FieldAccessMethods(val arrayAccess: OpNodes.FieldAccess) extends AnyVal: - def typeDecl: Iterator[TypeDecl] = resolveTypeDecl(arrayAccess.argument(1)) + def typeDecl: Iterator[TypeDecl] = resolveTypeDecl(arrayAccess.argument(1)) - private def resolveTypeDecl(expr: Expression): Iterator[TypeDecl] = - expr match - case x: Identifier => x.typ.referencedTypeDecl - case x: Literal => x.typ.referencedTypeDecl - case x: Call => x.fieldAccess.member.typ.referencedTypeDecl - case _ => Iterator.empty + private def resolveTypeDecl(expr: Expression): Iterator[TypeDecl] = + expr match + case x: Identifier => x.typ.referencedTypeDecl + case x: Literal => x.typ.referencedTypeDecl + case x: Call => x.fieldAccess.member.typ.referencedTypeDecl + case _ => Iterator.empty - def fieldIdentifier: Iterator[FieldIdentifier] = arrayAccess.start.argument(2).isFieldIdentifier + def fieldIdentifier: Iterator[FieldIdentifier] = arrayAccess.start.argument(2).isFieldIdentifier - def member: Option[Member] = - arrayAccess.referencedMember.headOption + def member: Option[Member] = + arrayAccess.referencedMember.headOption diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/OpAstNodeMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/OpAstNodeMethods.scala index 6fa6dcd3..fde230ab 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/OpAstNodeMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/OpAstNodeMethods.scala @@ -6,38 +6,38 @@ import io.shiftleft.semanticcpg.language.operatorextension.* class OpAstNodeMethods[A <: AstNode](val node: A) extends AnyVal: - def assignment: Iterator[OpNodes.Assignment] = - astDown(allAssignmentTypes).map(new OpNodes.Assignment(_)) + def assignment: Iterator[OpNodes.Assignment] = + astDown(allAssignmentTypes).map(new OpNodes.Assignment(_)) - def arithmetic: Iterator[OpNodes.Arithmetic] = - astDown(allArithmeticTypes).map(new OpNodes.Arithmetic(_)) + def arithmetic: Iterator[OpNodes.Arithmetic] = + astDown(allArithmeticTypes).map(new OpNodes.Arithmetic(_)) - def arrayAccess: Iterator[OpNodes.ArrayAccess] = - astDown(allArrayAccessTypes).map(new OpNodes.ArrayAccess(_)) + def arrayAccess: Iterator[OpNodes.ArrayAccess] = + astDown(allArrayAccessTypes).map(new OpNodes.ArrayAccess(_)) - def fieldAccess: Iterator[OpNodes.FieldAccess] = - astDown(allFieldAccessTypes).map(new OpNodes.FieldAccess(_)) + def fieldAccess: Iterator[OpNodes.FieldAccess] = + astDown(allFieldAccessTypes).map(new OpNodes.FieldAccess(_)) - private def astDown(callNames: Set[String]): Iterator[Call] = - node.ast.isCall.filter(x => callNames.contains(x.name)) + private def astDown(callNames: Set[String]): Iterator[Call] = + node.ast.isCall.filter(x => callNames.contains(x.name)) - def inAssignment: Iterator[OpNodes.Assignment] = - astUp(allAssignmentTypes) - .map(new OpNodes.Assignment(_)) + def inAssignment: Iterator[OpNodes.Assignment] = + astUp(allAssignmentTypes) + .map(new OpNodes.Assignment(_)) - def inArithmetic: Iterator[OpNodes.Arithmetic] = - astUp(allArithmeticTypes) - .map(new OpNodes.Arithmetic(_)) + def inArithmetic: Iterator[OpNodes.Arithmetic] = + astUp(allArithmeticTypes) + .map(new OpNodes.Arithmetic(_)) - def inArrayAccess: Iterator[OpNodes.ArrayAccess] = - astUp(allArrayAccessTypes) - .map(new OpNodes.ArrayAccess(_)) + def inArrayAccess: Iterator[OpNodes.ArrayAccess] = + astUp(allArrayAccessTypes) + .map(new OpNodes.ArrayAccess(_)) - def inFieldAccess: Iterator[OpNodes.FieldAccess] = - astUp(allFieldAccessTypes) - .map(new OpNodes.FieldAccess(_)) + def inFieldAccess: Iterator[OpNodes.FieldAccess] = + astUp(allFieldAccessTypes) + .map(new OpNodes.FieldAccess(_)) - private def astUp(strings: Set[String]): Iterator[Call] = - node.inAstMinusLeaf.isCall - .filter(x => strings.contains(x.name)) + private def astUp(strings: Set[String]): Iterator[Call] = + node.inAstMinusLeaf.isCall + .filter(x => strings.contains(x.name)) end OpAstNodeMethods diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/TargetMethods.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/TargetMethods.scala index f0a90fb9..1072d414 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/TargetMethods.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/nodemethods/TargetMethods.scala @@ -7,12 +7,12 @@ import io.shiftleft.semanticcpg.language.operatorextension.{OpNodes, allArrayAcc class TargetMethods(val expr: Expression) extends AnyVal: - def arrayAccess: Option[OpNodes.ArrayAccess] = - expr.ast.isCall - .collectFirst { case x if allArrayAccessTypes.contains(x.name) => x } - .map(new OpNodes.ArrayAccess(_)) + def arrayAccess: Option[OpNodes.ArrayAccess] = + expr.ast.isCall + .collectFirst { case x if allArrayAccessTypes.contains(x.name) => x } + .map(new OpNodes.ArrayAccess(_)) - def pointer: Option[Expression] = - Option(expr).collect { - case call: Call if call.name == Operators.indirection => call.argument(1) - } + def pointer: Option[Expression] = + Option(expr).collect { + case call: Call if call.name == Operators.indirection => call.argument(1) + } diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/package.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/package.scala index 8f5d9e12..1f73545c 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/package.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/operatorextension/package.scala @@ -4,54 +4,54 @@ import io.shiftleft.codepropertygraph.generated.Operators package object operatorextension: - /** All operators that perform both assignments and arithmetic. - */ - val assignmentAndArithmetic: Set[String] = Set( - Operators.assignmentDivision, - Operators.assignmentExponentiation, - Operators.assignmentPlus, - Operators.assignmentMinus, - Operators.assignmentModulo, - Operators.assignmentMultiplication, - Operators.preIncrement, - Operators.preDecrement, - Operators.postIncrement, - Operators.postIncrement - ) + /** All operators that perform both assignments and arithmetic. + */ + val assignmentAndArithmetic: Set[String] = Set( + Operators.assignmentDivision, + Operators.assignmentExponentiation, + Operators.assignmentPlus, + Operators.assignmentMinus, + Operators.assignmentModulo, + Operators.assignmentMultiplication, + Operators.preIncrement, + Operators.preDecrement, + Operators.postIncrement, + Operators.postIncrement + ) - /** All operators that carry out assignments. - */ - val allAssignmentTypes: Set[String] = Set( - Operators.assignment, - Operators.assignmentOr, - Operators.assignmentAnd, - Operators.assignmentXor, - Operators.assignmentArithmeticShiftRight, - Operators.assignmentLogicalShiftRight, - Operators.assignmentShiftLeft - ) ++ assignmentAndArithmetic + /** All operators that carry out assignments. + */ + val allAssignmentTypes: Set[String] = Set( + Operators.assignment, + Operators.assignmentOr, + Operators.assignmentAnd, + Operators.assignmentXor, + Operators.assignmentArithmeticShiftRight, + Operators.assignmentLogicalShiftRight, + Operators.assignmentShiftLeft + ) ++ assignmentAndArithmetic - /** All operators representing arithmetic. - */ - val allArithmeticTypes: Set[String] = Set( - Operators.addition, - Operators.subtraction, - Operators.division, - Operators.multiplication, - Operators.exponentiation, - Operators.modulo - ) ++ assignmentAndArithmetic + /** All operators representing arithmetic. + */ + val allArithmeticTypes: Set[String] = Set( + Operators.addition, + Operators.subtraction, + Operators.division, + Operators.multiplication, + Operators.exponentiation, + Operators.modulo + ) ++ assignmentAndArithmetic - /** All operators representing array accesses. - */ - val allArrayAccessTypes: Set[String] = Set( - Operators.computedMemberAccess, - Operators.indirectComputedMemberAccess, - Operators.indexAccess, - Operators.indirectIndexAccess - ) + /** All operators representing array accesses. + */ + val allArrayAccessTypes: Set[String] = Set( + Operators.computedMemberAccess, + Operators.indirectComputedMemberAccess, + Operators.indexAccess, + Operators.indirectIndexAccess + ) - /** All operators representing direct or indirect accesses to fields of data structures - */ - val allFieldAccessTypes: Set[String] = Set(Operators.fieldAccess, Operators.indirectFieldAccess) + /** All operators representing direct or indirect accesses to fields of data structures + */ + val allFieldAccessTypes: Set[String] = Set(Operators.fieldAccess, Operators.indirectFieldAccess) end operatorextension diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/package.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/package.scala index d76df0e8..06b28549 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/package.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/package.scala @@ -32,289 +32,286 @@ import overflowdb.NodeOrDetachedNode */ package object language extends operatorextension.Implicits with LowPrioImplicits with NodeTraversalImplicits: - // Implicit conversions from generated node types. We use these to add methods - // to generated node types. - - implicit def cfgNodeToAsNode(node: CfgNode): AstNodeMethods = new AstNodeMethods(node) - implicit def toExtendedNode(node: NodeOrDetachedNode): NodeMethods = new NodeMethods(node) - implicit def toExtendedStoredNode(node: StoredNode): StoredNodeMethods = - new StoredNodeMethods(node) - implicit def toAstNodeMethods(node: AstNode): AstNodeMethods = new AstNodeMethods(node) - implicit def toCfgNodeMethods(node: CfgNode): CfgNodeMethods = new CfgNodeMethods(node) - implicit def toExpressionMethods(node: Expression): ExpressionMethods = - new ExpressionMethods(node) - implicit def toMethodMethods(node: Method): MethodMethods = new MethodMethods(node) - implicit def toMethodReturnMethods(node: MethodReturn): MethodReturnMethods = - new MethodReturnMethods(node) - implicit def toCallMethods(node: Call): CallMethods = new CallMethods(node) - implicit def toMethodParamInMethods(node: MethodParameterIn): MethodParameterInMethods = - new MethodParameterInMethods(node) - implicit def toMethodParamOutMethods(node: MethodParameterOut): MethodParameterOutMethods = - new MethodParameterOutMethods(node) - implicit def toIdentifierMethods(node: Identifier): IdentifierMethods = - new IdentifierMethods(node) - implicit def toLiteralMethods(node: Literal): LiteralMethods = new LiteralMethods(node) - implicit def toLocalMethods(node: Local): LocalMethods = new LocalMethods(node) - implicit def toMethodRefMethods(node: MethodRef): MethodRefMethods = new MethodRefMethods(node) - - // Implicit conversions from Step[NodeType, Label] to corresponding Step classes. - // If you introduce a new Step-type, that is, one that inherits from `Steps[NodeType]`, - // then you need to add an implicit conversion from `Steps[NodeType]` to your type - // here. - - implicit def singleToTypeTrav[A <: Type](a: A): TypeTraversal = - new TypeTraversal(Iterator.single(a)) - implicit def iterOnceToTypeTrav[A <: Type](a: IterableOnce[A]): TypeTraversal = - new TypeTraversal(a.iterator) - - implicit def singleToTypeDeclTrav[A <: TypeDecl](a: A): TypeDeclTraversal = - new TypeDeclTraversal(Iterator.single(a)) - implicit def iterOnceToTypeDeclTrav[A <: TypeDecl](a: IterableOnce[A]): TypeDeclTraversal = - new TypeDeclTraversal(a.iterator) - - implicit def iterOnceToOriginalCallTrav[A <: Call](a: IterableOnce[A]): OriginalCall = - new OriginalCall(a.iterator) - - implicit def singleToControlStructureTrav[A <: ControlStructure](a: A) - : ControlStructureTraversal = - new ControlStructureTraversal(Iterator.single(a)) - implicit def iterOnceToControlStructureTrav[A <: ControlStructure](a: IterableOnce[A]) - : ControlStructureTraversal = - new ControlStructureTraversal(a.iterator) - - implicit def singleToIdentifierTrav[A <: Identifier](a: A): IdentifierTraversal = - new IdentifierTraversal(Iterator.single(a)) - implicit def iterOnceToIdentifierTrav[A <: Identifier](a: IterableOnce[A]) - : IdentifierTraversal = - new IdentifierTraversal(a.iterator) - - implicit def singleToAnnotationTrav[A <: Annotation](a: A): AnnotationTraversal = - new AnnotationTraversal(Iterator.single(a)) - implicit def iterOnceToAnnotationTrav[A <: Annotation](a: IterableOnce[A]) - : AnnotationTraversal = - new AnnotationTraversal(a.iterator) - - implicit def singleToDependencyTrav[A <: Dependency](a: A): DependencyTraversal = - new DependencyTraversal(Iterator.single(a)) - - implicit def iterToDependencyTrav[A <: Dependency](a: IterableOnce[A]): DependencyTraversal = - new DependencyTraversal(a.iterator) - - implicit def singleToAnnotationParameterAssignTrav[A <: AnnotationParameterAssign]( - a: A - ): AnnotationParameterAssignTraversal = - new AnnotationParameterAssignTraversal(Iterator.single(a)) - implicit def iterOnceToAnnotationParameterAssignTrav[A <: AnnotationParameterAssign]( - a: IterableOnce[A] - ): AnnotationParameterAssignTraversal = - new AnnotationParameterAssignTraversal(a.iterator) - - implicit def toMember(traversal: IterableOnce[Member]): MemberTraversal = - new MemberTraversal(traversal.iterator) - implicit def toLocal(traversal: IterableOnce[Local]): LocalTraversal = - new LocalTraversal(traversal.iterator) - implicit def toMethod(traversal: IterableOnce[Method]): OriginalMethod = - new OriginalMethod(traversal.iterator) - - implicit def singleToMethodParameterInTrav[A <: MethodParameterIn](a: A) - : MethodParameterTraversal = - new MethodParameterTraversal(Iterator.single(a)) - implicit def iterOnceToMethodParameterInTrav[A <: MethodParameterIn](a: IterableOnce[A]) - : MethodParameterTraversal = - new MethodParameterTraversal(a.iterator) - - implicit def singleToMethodParameterOutTrav[A <: MethodParameterOut](a: A) - : MethodParameterOutTraversal = - new MethodParameterOutTraversal(Iterator.single(a)) - implicit def iterOnceToMethodParameterOutTrav[A <: MethodParameterOut]( - a: IterableOnce[A] - ): MethodParameterOutTraversal = - new MethodParameterOutTraversal(a.iterator) - - implicit def iterOnceToMethodReturnTrav[A <: MethodReturn](a: IterableOnce[A]) - : MethodReturnTraversal = - new MethodReturnTraversal(a.iterator) - - implicit def singleToNamespaceTrav[A <: Namespace](a: A): NamespaceTraversal = - new NamespaceTraversal(Iterator.single(a)) - implicit def iterOnceToNamespaceTrav[A <: Namespace](a: IterableOnce[A]): NamespaceTraversal = - new NamespaceTraversal(a.iterator) - - implicit def singleToNamespaceBlockTrav[A <: NamespaceBlock](a: A): NamespaceBlockTraversal = - new NamespaceBlockTraversal(Iterator.single(a)) - implicit def iterOnceToNamespaceBlockTrav[A <: NamespaceBlock](a: IterableOnce[A]) - : NamespaceBlockTraversal = - new NamespaceBlockTraversal(a.iterator) - - implicit def singleToFileTrav[A <: File](a: A): FileTraversal = - new FileTraversal(Iterator.single(a)) - implicit def iterOnceToFileTrav[A <: File](a: IterableOnce[A]): FileTraversal = - new FileTraversal(a.iterator) - - implicit def singleToImportTrav[A <: Import](a: A): ImportTraversal = - new ImportTraversal(Iterator.single(a)) - - implicit def iterToImportTrav[A <: Import](a: IterableOnce[A]): ImportTraversal = - new ImportTraversal(a.iterator) - - // Call graph extension - implicit def singleToMethodTravCallGraphExt[A <: Method](a: A): MethodTraversal = - new MethodTraversal(Iterator.single(a)) - implicit def iterOnceToMethodTravCallGraphExt[A <: Method](a: IterableOnce[A]) - : MethodTraversal = - new MethodTraversal(a.iterator) - implicit def singleToCallTrav[A <: Call](a: A): CallTraversal = - new CallTraversal(Iterator.single(a)) - implicit def iterOnceToCallTrav[A <: Call](a: IterableOnce[A]): CallTraversal = - new CallTraversal(a.iterator) - // / Call graph extension - - // Binding extensions - implicit def singleToBindingMethodTrav[A <: Method](a: A): BindingMethodTraversal = - new BindingMethodTraversal(Iterator.single(a)) - implicit def iterOnceToBindingMethodTrav[A <: Method](a: IterableOnce[A]) - : BindingMethodTraversal = - new BindingMethodTraversal(a.iterator) - - implicit def singleToBindingTypeDeclTrav[A <: TypeDecl](a: A): BindingTypeDeclTraversal = - new BindingTypeDeclTraversal(Iterator.single(a)) - implicit def iterOnceToBindingTypeDeclTrav[A <: TypeDecl](a: IterableOnce[A]) - : BindingTypeDeclTraversal = - new BindingTypeDeclTraversal(a.iterator) - - implicit def singleToAstNodeDot[A <: AstNode](a: A): AstNodeDot[A] = - new AstNodeDot(Iterator.single(a)) - implicit def iterOnceToAstNodeDot[A <: AstNode](a: IterableOnce[A]): AstNodeDot[A] = - new AstNodeDot(a.iterator) - - implicit def singleToCfgNodeDot[A <: Method](a: A): CfgNodeDot = - new CfgNodeDot(Iterator.single(a)) - implicit def iterOnceToCfgNodeDot[A <: Method](a: IterableOnce[A]): CfgNodeDot = - new CfgNodeDot(a.iterator) - - implicit def graphToInterproceduralDot(cpg: Cpg): InterproceduralNodeDot = - new InterproceduralNodeDot(cpg) - - /** Warning: implicitly lifting `Node -> Traversal` opens a broad space with a lot of accidental - * complexity and is considered a historical accident. We only keep it around because we want - * to preserve `reachableBy(Node*)`, which unfortunately (due to type erasure) can't be an - * overload of `reachableBy(Traversal*)`. - * - * In most places you should explicitly call `Iterator.single` instead of relying on this - * implicit. + // Implicit conversions from generated node types. We use these to add methods + // to generated node types. + + implicit def cfgNodeToAsNode(node: CfgNode): AstNodeMethods = new AstNodeMethods(node) + implicit def toExtendedNode(node: NodeOrDetachedNode): NodeMethods = new NodeMethods(node) + implicit def toExtendedStoredNode(node: StoredNode): StoredNodeMethods = + new StoredNodeMethods(node) + implicit def toAstNodeMethods(node: AstNode): AstNodeMethods = new AstNodeMethods(node) + implicit def toCfgNodeMethods(node: CfgNode): CfgNodeMethods = new CfgNodeMethods(node) + implicit def toExpressionMethods(node: Expression): ExpressionMethods = + new ExpressionMethods(node) + implicit def toMethodMethods(node: Method): MethodMethods = new MethodMethods(node) + implicit def toMethodReturnMethods(node: MethodReturn): MethodReturnMethods = + new MethodReturnMethods(node) + implicit def toCallMethods(node: Call): CallMethods = new CallMethods(node) + implicit def toMethodParamInMethods(node: MethodParameterIn): MethodParameterInMethods = + new MethodParameterInMethods(node) + implicit def toMethodParamOutMethods(node: MethodParameterOut): MethodParameterOutMethods = + new MethodParameterOutMethods(node) + implicit def toIdentifierMethods(node: Identifier): IdentifierMethods = + new IdentifierMethods(node) + implicit def toLiteralMethods(node: Literal): LiteralMethods = new LiteralMethods(node) + implicit def toLocalMethods(node: Local): LocalMethods = new LocalMethods(node) + implicit def toMethodRefMethods(node: MethodRef): MethodRefMethods = new MethodRefMethods(node) + + // Implicit conversions from Step[NodeType, Label] to corresponding Step classes. + // If you introduce a new Step-type, that is, one that inherits from `Steps[NodeType]`, + // then you need to add an implicit conversion from `Steps[NodeType]` to your type + // here. + + implicit def singleToTypeTrav[A <: Type](a: A): TypeTraversal = + new TypeTraversal(Iterator.single(a)) + implicit def iterOnceToTypeTrav[A <: Type](a: IterableOnce[A]): TypeTraversal = + new TypeTraversal(a.iterator) + + implicit def singleToTypeDeclTrav[A <: TypeDecl](a: A): TypeDeclTraversal = + new TypeDeclTraversal(Iterator.single(a)) + implicit def iterOnceToTypeDeclTrav[A <: TypeDecl](a: IterableOnce[A]): TypeDeclTraversal = + new TypeDeclTraversal(a.iterator) + + implicit def iterOnceToOriginalCallTrav[A <: Call](a: IterableOnce[A]): OriginalCall = + new OriginalCall(a.iterator) + + implicit def singleToControlStructureTrav[A <: ControlStructure](a: A) + : ControlStructureTraversal = + new ControlStructureTraversal(Iterator.single(a)) + implicit def iterOnceToControlStructureTrav[A <: ControlStructure](a: IterableOnce[A]) + : ControlStructureTraversal = + new ControlStructureTraversal(a.iterator) + + implicit def singleToIdentifierTrav[A <: Identifier](a: A): IdentifierTraversal = + new IdentifierTraversal(Iterator.single(a)) + implicit def iterOnceToIdentifierTrav[A <: Identifier](a: IterableOnce[A]): IdentifierTraversal = + new IdentifierTraversal(a.iterator) + + implicit def singleToAnnotationTrav[A <: Annotation](a: A): AnnotationTraversal = + new AnnotationTraversal(Iterator.single(a)) + implicit def iterOnceToAnnotationTrav[A <: Annotation](a: IterableOnce[A]): AnnotationTraversal = + new AnnotationTraversal(a.iterator) + + implicit def singleToDependencyTrav[A <: Dependency](a: A): DependencyTraversal = + new DependencyTraversal(Iterator.single(a)) + + implicit def iterToDependencyTrav[A <: Dependency](a: IterableOnce[A]): DependencyTraversal = + new DependencyTraversal(a.iterator) + + implicit def singleToAnnotationParameterAssignTrav[A <: AnnotationParameterAssign]( + a: A + ): AnnotationParameterAssignTraversal = + new AnnotationParameterAssignTraversal(Iterator.single(a)) + implicit def iterOnceToAnnotationParameterAssignTrav[A <: AnnotationParameterAssign]( + a: IterableOnce[A] + ): AnnotationParameterAssignTraversal = + new AnnotationParameterAssignTraversal(a.iterator) + + implicit def toMember(traversal: IterableOnce[Member]): MemberTraversal = + new MemberTraversal(traversal.iterator) + implicit def toLocal(traversal: IterableOnce[Local]): LocalTraversal = + new LocalTraversal(traversal.iterator) + implicit def toMethod(traversal: IterableOnce[Method]): OriginalMethod = + new OriginalMethod(traversal.iterator) + + implicit def singleToMethodParameterInTrav[A <: MethodParameterIn](a: A) + : MethodParameterTraversal = + new MethodParameterTraversal(Iterator.single(a)) + implicit def iterOnceToMethodParameterInTrav[A <: MethodParameterIn](a: IterableOnce[A]) + : MethodParameterTraversal = + new MethodParameterTraversal(a.iterator) + + implicit def singleToMethodParameterOutTrav[A <: MethodParameterOut](a: A) + : MethodParameterOutTraversal = + new MethodParameterOutTraversal(Iterator.single(a)) + implicit def iterOnceToMethodParameterOutTrav[A <: MethodParameterOut]( + a: IterableOnce[A] + ): MethodParameterOutTraversal = + new MethodParameterOutTraversal(a.iterator) + + implicit def iterOnceToMethodReturnTrav[A <: MethodReturn](a: IterableOnce[A]) + : MethodReturnTraversal = + new MethodReturnTraversal(a.iterator) + + implicit def singleToNamespaceTrav[A <: Namespace](a: A): NamespaceTraversal = + new NamespaceTraversal(Iterator.single(a)) + implicit def iterOnceToNamespaceTrav[A <: Namespace](a: IterableOnce[A]): NamespaceTraversal = + new NamespaceTraversal(a.iterator) + + implicit def singleToNamespaceBlockTrav[A <: NamespaceBlock](a: A): NamespaceBlockTraversal = + new NamespaceBlockTraversal(Iterator.single(a)) + implicit def iterOnceToNamespaceBlockTrav[A <: NamespaceBlock](a: IterableOnce[A]) + : NamespaceBlockTraversal = + new NamespaceBlockTraversal(a.iterator) + + implicit def singleToFileTrav[A <: File](a: A): FileTraversal = + new FileTraversal(Iterator.single(a)) + implicit def iterOnceToFileTrav[A <: File](a: IterableOnce[A]): FileTraversal = + new FileTraversal(a.iterator) + + implicit def singleToImportTrav[A <: Import](a: A): ImportTraversal = + new ImportTraversal(Iterator.single(a)) + + implicit def iterToImportTrav[A <: Import](a: IterableOnce[A]): ImportTraversal = + new ImportTraversal(a.iterator) + + // Call graph extension + implicit def singleToMethodTravCallGraphExt[A <: Method](a: A): MethodTraversal = + new MethodTraversal(Iterator.single(a)) + implicit def iterOnceToMethodTravCallGraphExt[A <: Method](a: IterableOnce[A]): MethodTraversal = + new MethodTraversal(a.iterator) + implicit def singleToCallTrav[A <: Call](a: A): CallTraversal = + new CallTraversal(Iterator.single(a)) + implicit def iterOnceToCallTrav[A <: Call](a: IterableOnce[A]): CallTraversal = + new CallTraversal(a.iterator) + // / Call graph extension + + // Binding extensions + implicit def singleToBindingMethodTrav[A <: Method](a: A): BindingMethodTraversal = + new BindingMethodTraversal(Iterator.single(a)) + implicit def iterOnceToBindingMethodTrav[A <: Method](a: IterableOnce[A]) + : BindingMethodTraversal = + new BindingMethodTraversal(a.iterator) + + implicit def singleToBindingTypeDeclTrav[A <: TypeDecl](a: A): BindingTypeDeclTraversal = + new BindingTypeDeclTraversal(Iterator.single(a)) + implicit def iterOnceToBindingTypeDeclTrav[A <: TypeDecl](a: IterableOnce[A]) + : BindingTypeDeclTraversal = + new BindingTypeDeclTraversal(a.iterator) + + implicit def singleToAstNodeDot[A <: AstNode](a: A): AstNodeDot[A] = + new AstNodeDot(Iterator.single(a)) + implicit def iterOnceToAstNodeDot[A <: AstNode](a: IterableOnce[A]): AstNodeDot[A] = + new AstNodeDot(a.iterator) + + implicit def singleToCfgNodeDot[A <: Method](a: A): CfgNodeDot = + new CfgNodeDot(Iterator.single(a)) + implicit def iterOnceToCfgNodeDot[A <: Method](a: IterableOnce[A]): CfgNodeDot = + new CfgNodeDot(a.iterator) + + implicit def graphToInterproceduralDot(cpg: Cpg): InterproceduralNodeDot = + new InterproceduralNodeDot(cpg) + + /** Warning: implicitly lifting `Node -> Traversal` opens a broad space with a lot of accidental + * complexity and is considered a historical accident. We only keep it around because we want to + * preserve `reachableBy(Node*)`, which unfortunately (due to type erasure) can't be an overload + * of `reachableBy(Traversal*)`. + * + * In most places you should explicitly call `Iterator.single` instead of relying on this + * implicit. + */ + implicit def toTraversal[NodeType <: StoredNode](node: NodeType): Iterator[NodeType] = + Iterator.single(node) + + implicit def iterableOnceToSteps[A](iterableOnce: IterableOnce[A]): Steps[A] = + new Steps(iterableOnce.iterator) + + implicit def traversalToSteps[A](trav: Iterator[A]): Steps[A] = + new Steps(trav) + implicit def iterOnceToNodeSteps[A <: StoredNode](a: IterableOnce[A]): NodeSteps[A] = + new NodeSteps[A](a.iterator) + + implicit def toNewNodeTrav[NodeType <: NewNode](trav: Iterator[NodeType]) + : NewNodeSteps[NodeType] = + new NewNodeSteps[NodeType](trav) + + implicit def toNodeTypeStarters(cpg: Cpg): NodeTypeStarters = new NodeTypeStarters(cpg) + implicit def toTagTraversal(trav: Iterator[Tag]): TagTraversal = new TagTraversal(trav) + + // ~ EvalType accessors + implicit def singleToEvalTypeAccessorsLocal[A <: Local](a: A): EvalTypeAccessors[A] = + new EvalTypeAccessors[A](Iterator.single(a)) + implicit def iterOnceToEvalTypeAccessorsLocal[A <: Local](a: IterableOnce[A]) + : EvalTypeAccessors[A] = + new EvalTypeAccessors[A](a.iterator) + + implicit def singleToEvalTypeAccessorsMember[A <: Member](a: A): EvalTypeAccessors[A] = + new EvalTypeAccessors[A](Iterator.single(a)) + implicit def iterOnceToEvalTypeAccessorsMember[A <: Member](a: IterableOnce[A]) + : EvalTypeAccessors[A] = + new EvalTypeAccessors[A](a.iterator) + + implicit def singleToEvalTypeAccessorsMethod[A <: Method](a: A): EvalTypeAccessors[A] = + new EvalTypeAccessors[A](Iterator.single(a)) + implicit def iterOnceToEvalTypeAccessorsMethod[A <: Method](a: IterableOnce[A]) + : EvalTypeAccessors[A] = + new EvalTypeAccessors[A](a.iterator) + + implicit def singleToEvalTypeAccessorsParameterIn[A <: MethodParameterIn](a: A) + : EvalTypeAccessors[A] = + new EvalTypeAccessors[A](Iterator.single(a)) + implicit def iterOnceToEvalTypeAccessorsParameterIn[A <: MethodParameterIn]( + a: IterableOnce[A] + ): EvalTypeAccessors[A] = + new EvalTypeAccessors[A](a.iterator) + + implicit def singleToEvalTypeAccessorsParameterOut[A <: MethodParameterOut](a: A) + : EvalTypeAccessors[A] = + new EvalTypeAccessors[A](Iterator.single(a)) + implicit def iterOnceToEvalTypeAccessorsParameterOut[A <: MethodParameterOut]( + a: IterableOnce[A] + ): EvalTypeAccessors[A] = + new EvalTypeAccessors[A](a.iterator) + + implicit def singleToEvalTypeAccessorsMethodReturn[A <: MethodReturn](a: A) + : EvalTypeAccessors[A] = + new EvalTypeAccessors[A](Iterator.single(a)) + implicit def iterOnceToEvalTypeAccessorsMethodReturn[A <: MethodReturn](a: IterableOnce[A]) + : EvalTypeAccessors[A] = + new EvalTypeAccessors[A](a.iterator) + + implicit def singleToEvalTypeAccessorsExpression[A <: Expression](a: A): EvalTypeAccessors[A] = + new EvalTypeAccessors[A](Iterator.single(a)) + implicit def iterOnceToEvalTypeAccessorsExpression[A <: Expression](a: IterableOnce[A]) + : EvalTypeAccessors[A] = + new EvalTypeAccessors[A](a.iterator) + + // EvalType accessors ~ + + // ~ Modifier accessors + implicit def singleToModifierAccessorsMember[A <: Member](a: A): ModifierAccessors[A] = + new ModifierAccessors[A](Iterator.single(a)) + implicit def iterOnceToModifierAccessorsMember[A <: Member](a: IterableOnce[A]) + : ModifierAccessors[A] = + new ModifierAccessors[A](a.iterator) + + implicit def singleToModifierAccessorsMethod[A <: Method](a: A): ModifierAccessors[A] = + new ModifierAccessors[A](Iterator.single(a)) + implicit def iterOnceToModifierAccessorsMethod[A <: Method](a: IterableOnce[A]) + : ModifierAccessors[A] = + new ModifierAccessors[A](a.iterator) + + implicit def singleToModifierAccessorsTypeDecl[A <: TypeDecl](a: A): ModifierAccessors[A] = + new ModifierAccessors[A](Iterator.single(a)) + implicit def iterOnceToModifierAccessorsTypeDecl[A <: TypeDecl](a: IterableOnce[A]) + : ModifierAccessors[A] = + new ModifierAccessors[A](a.iterator) + // Modifier accessors ~ + + implicit class NewNodeTypeDeco[NodeType <: NewNode](val node: NodeType) extends AnyVal: + + /** Start a new traversal from this node */ - implicit def toTraversal[NodeType <: StoredNode](node: NodeType): Iterator[NodeType] = + def start: Iterator[NodeType] = Iterator.single(node) - implicit def iterableOnceToSteps[A](iterableOnce: IterableOnce[A]): Steps[A] = - new Steps(iterableOnce.iterator) - - implicit def traversalToSteps[A](trav: Iterator[A]): Steps[A] = - new Steps(trav) - implicit def iterOnceToNodeSteps[A <: StoredNode](a: IterableOnce[A]): NodeSteps[A] = - new NodeSteps[A](a.iterator) - - implicit def toNewNodeTrav[NodeType <: NewNode](trav: Iterator[NodeType]) - : NewNodeSteps[NodeType] = - new NewNodeSteps[NodeType](trav) - - implicit def toNodeTypeStarters(cpg: Cpg): NodeTypeStarters = new NodeTypeStarters(cpg) - implicit def toTagTraversal(trav: Iterator[Tag]): TagTraversal = new TagTraversal(trav) - - // ~ EvalType accessors - implicit def singleToEvalTypeAccessorsLocal[A <: Local](a: A): EvalTypeAccessors[A] = - new EvalTypeAccessors[A](Iterator.single(a)) - implicit def iterOnceToEvalTypeAccessorsLocal[A <: Local](a: IterableOnce[A]) - : EvalTypeAccessors[A] = - new EvalTypeAccessors[A](a.iterator) - - implicit def singleToEvalTypeAccessorsMember[A <: Member](a: A): EvalTypeAccessors[A] = - new EvalTypeAccessors[A](Iterator.single(a)) - implicit def iterOnceToEvalTypeAccessorsMember[A <: Member](a: IterableOnce[A]) - : EvalTypeAccessors[A] = - new EvalTypeAccessors[A](a.iterator) - - implicit def singleToEvalTypeAccessorsMethod[A <: Method](a: A): EvalTypeAccessors[A] = - new EvalTypeAccessors[A](Iterator.single(a)) - implicit def iterOnceToEvalTypeAccessorsMethod[A <: Method](a: IterableOnce[A]) - : EvalTypeAccessors[A] = - new EvalTypeAccessors[A](a.iterator) - - implicit def singleToEvalTypeAccessorsParameterIn[A <: MethodParameterIn](a: A) - : EvalTypeAccessors[A] = - new EvalTypeAccessors[A](Iterator.single(a)) - implicit def iterOnceToEvalTypeAccessorsParameterIn[A <: MethodParameterIn]( - a: IterableOnce[A] - ): EvalTypeAccessors[A] = - new EvalTypeAccessors[A](a.iterator) - - implicit def singleToEvalTypeAccessorsParameterOut[A <: MethodParameterOut](a: A) - : EvalTypeAccessors[A] = - new EvalTypeAccessors[A](Iterator.single(a)) - implicit def iterOnceToEvalTypeAccessorsParameterOut[A <: MethodParameterOut]( - a: IterableOnce[A] - ): EvalTypeAccessors[A] = - new EvalTypeAccessors[A](a.iterator) - - implicit def singleToEvalTypeAccessorsMethodReturn[A <: MethodReturn](a: A) - : EvalTypeAccessors[A] = - new EvalTypeAccessors[A](Iterator.single(a)) - implicit def iterOnceToEvalTypeAccessorsMethodReturn[A <: MethodReturn](a: IterableOnce[A]) - : EvalTypeAccessors[A] = - new EvalTypeAccessors[A](a.iterator) - - implicit def singleToEvalTypeAccessorsExpression[A <: Expression](a: A): EvalTypeAccessors[A] = - new EvalTypeAccessors[A](Iterator.single(a)) - implicit def iterOnceToEvalTypeAccessorsExpression[A <: Expression](a: IterableOnce[A]) - : EvalTypeAccessors[A] = - new EvalTypeAccessors[A](a.iterator) - - // EvalType accessors ~ - - // ~ Modifier accessors - implicit def singleToModifierAccessorsMember[A <: Member](a: A): ModifierAccessors[A] = - new ModifierAccessors[A](Iterator.single(a)) - implicit def iterOnceToModifierAccessorsMember[A <: Member](a: IterableOnce[A]) - : ModifierAccessors[A] = - new ModifierAccessors[A](a.iterator) - - implicit def singleToModifierAccessorsMethod[A <: Method](a: A): ModifierAccessors[A] = - new ModifierAccessors[A](Iterator.single(a)) - implicit def iterOnceToModifierAccessorsMethod[A <: Method](a: IterableOnce[A]) - : ModifierAccessors[A] = - new ModifierAccessors[A](a.iterator) - - implicit def singleToModifierAccessorsTypeDecl[A <: TypeDecl](a: A): ModifierAccessors[A] = - new ModifierAccessors[A](Iterator.single(a)) - implicit def iterOnceToModifierAccessorsTypeDecl[A <: TypeDecl](a: IterableOnce[A]) - : ModifierAccessors[A] = - new ModifierAccessors[A](a.iterator) - // Modifier accessors ~ - - implicit class NewNodeTypeDeco[NodeType <: NewNode](val node: NodeType) extends AnyVal: - - /** Start a new traversal from this node - */ - def start: Iterator[NodeType] = - Iterator.single(node) - - implicit def toExpression[A <: Expression](a: IterableOnce[A]): ExpressionTraversal[A] = - new ExpressionTraversal[A](a.iterator) + implicit def toExpression[A <: Expression](a: IterableOnce[A]): ExpressionTraversal[A] = + new ExpressionTraversal[A](a.iterator) end language trait LowPrioImplicits extends overflowdb.traversal.Implicits: - implicit def singleToCfgNodeTraversal[A <: CfgNode](a: A): CfgNodeTraversal[A] = - new CfgNodeTraversal[A](Iterator.single(a)) - implicit def iterOnceToCfgNodeTraversal[A <: CfgNode](a: IterableOnce[A]): CfgNodeTraversal[A] = - new CfgNodeTraversal[A](a.iterator) - - implicit def singleToAstNodeTraversal[A <: AstNode](a: A): AstNodeTraversal[A] = - new AstNodeTraversal[A](Iterator.single(a)) - implicit def iterOnceToAstNodeTraversal[A <: AstNode](a: IterableOnce[A]): AstNodeTraversal[A] = - new AstNodeTraversal[A](a.iterator) - - implicit def singleToDeclarationNodeTraversal[A <: Declaration](a: A): DeclarationTraversal[A] = - new DeclarationTraversal[A](Iterator.single(a)) - implicit def iterOnceToDeclarationNodeTraversal[A <: Declaration](a: IterableOnce[A]) - : DeclarationTraversal[A] = - new DeclarationTraversal[A](a.iterator) + implicit def singleToCfgNodeTraversal[A <: CfgNode](a: A): CfgNodeTraversal[A] = + new CfgNodeTraversal[A](Iterator.single(a)) + implicit def iterOnceToCfgNodeTraversal[A <: CfgNode](a: IterableOnce[A]): CfgNodeTraversal[A] = + new CfgNodeTraversal[A](a.iterator) + + implicit def singleToAstNodeTraversal[A <: AstNode](a: A): AstNodeTraversal[A] = + new AstNodeTraversal[A](Iterator.single(a)) + implicit def iterOnceToAstNodeTraversal[A <: AstNode](a: IterableOnce[A]): AstNodeTraversal[A] = + new AstNodeTraversal[A](a.iterator) + + implicit def singleToDeclarationNodeTraversal[A <: Declaration](a: A): DeclarationTraversal[A] = + new DeclarationTraversal[A](Iterator.single(a)) + implicit def iterOnceToDeclarationNodeTraversal[A <: Declaration](a: IterableOnce[A]) + : DeclarationTraversal[A] = + new DeclarationTraversal[A](a.iterator) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/CallTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/CallTraversal.scala index 49f41804..dc876a56 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/CallTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/CallTraversal.scala @@ -7,35 +7,35 @@ import io.shiftleft.semanticcpg.language.* */ class CallTraversal(val traversal: Iterator[Call]) extends AnyVal: - /** Only statically dispatched calls - */ - def isStatic: Iterator[Call] = - traversal.dispatchType("STATIC_DISPATCH") - - /** Only dynamically dispatched calls - */ - def isDynamic: Iterator[Call] = - traversal.dispatchType("DYNAMIC_DISPATCH") - - /** The receiver of a call if the call has a receiver associated. - */ - def receiver: Iterator[Expression] = - traversal.flatMap(_.receiver) - - /** Arguments of the call - */ - def argument: Iterator[Expression] = - traversal.flatMap(_.argument) - - /** `i'th` arguments of the call - */ - def argument(i: Integer): Iterator[Expression] = - traversal.flatMap(_.arguments(i)) - - /** To formal method return parameter - */ - def toMethodReturn(implicit callResolver: ICallResolver): Iterator[MethodReturn] = - traversal - .flatMap(callResolver.getCalledMethodsAsTraversal) - .flatMap(_.methodReturn) + /** Only statically dispatched calls + */ + def isStatic: Iterator[Call] = + traversal.dispatchType("STATIC_DISPATCH") + + /** Only dynamically dispatched calls + */ + def isDynamic: Iterator[Call] = + traversal.dispatchType("DYNAMIC_DISPATCH") + + /** The receiver of a call if the call has a receiver associated. + */ + def receiver: Iterator[Expression] = + traversal.flatMap(_.receiver) + + /** Arguments of the call + */ + def argument: Iterator[Expression] = + traversal.flatMap(_.argument) + + /** `i'th` arguments of the call + */ + def argument(i: Integer): Iterator[Expression] = + traversal.flatMap(_.arguments(i)) + + /** To formal method return parameter + */ + def toMethodReturn(implicit callResolver: ICallResolver): Iterator[MethodReturn] = + traversal + .flatMap(callResolver.getCalledMethodsAsTraversal) + .flatMap(_.methodReturn) end CallTraversal diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/ControlStructureTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/ControlStructureTraversal.scala index 1d9c750d..8bcc88c6 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/ControlStructureTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/ControlStructureTraversal.scala @@ -6,69 +6,69 @@ import io.shiftleft.semanticcpg.language.* import overflowdb.traversal.help.Doc object ControlStructureTraversal: - val secondChildIndex = 2 - val thirdChildIndex = 3 + val secondChildIndex = 2 + val thirdChildIndex = 3 class ControlStructureTraversal(val traversal: Iterator[ControlStructure]) extends AnyVal: - import ControlStructureTraversal.* + import ControlStructureTraversal.* - @Doc(info = "The condition associated with this control structure") - def condition: Iterator[Expression] = - traversal.flatMap(_.conditionOut).collectAll[Expression] + @Doc(info = "The condition associated with this control structure") + def condition: Iterator[Expression] = + traversal.flatMap(_.conditionOut).collectAll[Expression] - @Doc(info = "Control structures where condition.code matches regex") - def condition(regex: String): Iterator[ControlStructure] = - traversal.where(_.condition.code(regex)) + @Doc(info = "Control structures where condition.code matches regex") + def condition(regex: String): Iterator[ControlStructure] = + traversal.where(_.condition.code(regex)) - @Doc(info = "Sub tree taken when condition evaluates to true") - def whenTrue: Iterator[AstNode] = - traversal.out.has(Properties.ORDER, secondChildIndex: Int).cast[AstNode] + @Doc(info = "Sub tree taken when condition evaluates to true") + def whenTrue: Iterator[AstNode] = + traversal.out.has(Properties.ORDER, secondChildIndex: Int).cast[AstNode] - @Doc(info = "Sub tree taken when condition evaluates to false") - def whenFalse: Iterator[AstNode] = - traversal.out.has(Properties.ORDER, thirdChildIndex).cast[AstNode] + @Doc(info = "Sub tree taken when condition evaluates to false") + def whenFalse: Iterator[AstNode] = + traversal.out.has(Properties.ORDER, thirdChildIndex).cast[AstNode] - @Doc(info = "Only `Try` control structures") - def isTry: Iterator[ControlStructure] = - traversal.controlStructureTypeExact(ControlStructureTypes.TRY) + @Doc(info = "Only `Try` control structures") + def isTry: Iterator[ControlStructure] = + traversal.controlStructureTypeExact(ControlStructureTypes.TRY) - @Doc(info = "Only `If` control structures") - def isIf: Iterator[ControlStructure] = - traversal.controlStructureTypeExact(ControlStructureTypes.IF) + @Doc(info = "Only `If` control structures") + def isIf: Iterator[ControlStructure] = + traversal.controlStructureTypeExact(ControlStructureTypes.IF) - @Doc(info = "Only `Else` control structures") - def isElse: Iterator[ControlStructure] = - traversal.controlStructureTypeExact(ControlStructureTypes.ELSE) + @Doc(info = "Only `Else` control structures") + def isElse: Iterator[ControlStructure] = + traversal.controlStructureTypeExact(ControlStructureTypes.ELSE) - @Doc(info = "Only `Switch` control structures") - def isSwitch: Iterator[ControlStructure] = - traversal.controlStructureTypeExact(ControlStructureTypes.SWITCH) + @Doc(info = "Only `Switch` control structures") + def isSwitch: Iterator[ControlStructure] = + traversal.controlStructureTypeExact(ControlStructureTypes.SWITCH) - @Doc(info = "Only `Do` control structures") - def isDo: Iterator[ControlStructure] = - traversal.controlStructureTypeExact(ControlStructureTypes.DO) + @Doc(info = "Only `Do` control structures") + def isDo: Iterator[ControlStructure] = + traversal.controlStructureTypeExact(ControlStructureTypes.DO) - @Doc(info = "Only `For` control structures") - def isFor: Iterator[ControlStructure] = - traversal.controlStructureTypeExact(ControlStructureTypes.FOR) + @Doc(info = "Only `For` control structures") + def isFor: Iterator[ControlStructure] = + traversal.controlStructureTypeExact(ControlStructureTypes.FOR) - @Doc(info = "Only `While` control structures") - def isWhile: Iterator[ControlStructure] = - traversal.controlStructureTypeExact(ControlStructureTypes.WHILE) + @Doc(info = "Only `While` control structures") + def isWhile: Iterator[ControlStructure] = + traversal.controlStructureTypeExact(ControlStructureTypes.WHILE) - @Doc(info = "Only `Goto` control structures") - def isGoto: Iterator[ControlStructure] = - traversal.controlStructureTypeExact(ControlStructureTypes.GOTO) + @Doc(info = "Only `Goto` control structures") + def isGoto: Iterator[ControlStructure] = + traversal.controlStructureTypeExact(ControlStructureTypes.GOTO) - @Doc(info = "Only `Break` control structures") - def isBreak: Iterator[ControlStructure] = - traversal.controlStructureTypeExact(ControlStructureTypes.BREAK) + @Doc(info = "Only `Break` control structures") + def isBreak: Iterator[ControlStructure] = + traversal.controlStructureTypeExact(ControlStructureTypes.BREAK) - @Doc(info = "Only `Continue` control structures") - def isContinue: Iterator[ControlStructure] = - traversal.controlStructureTypeExact(ControlStructureTypes.CONTINUE) + @Doc(info = "Only `Continue` control structures") + def isContinue: Iterator[ControlStructure] = + traversal.controlStructureTypeExact(ControlStructureTypes.CONTINUE) - @Doc(info = "Only `Throw` control structures") - def isThrow: Iterator[ControlStructure] = - traversal.controlStructureTypeExact(ControlStructureTypes.THROW) + @Doc(info = "Only `Throw` control structures") + def isThrow: Iterator[ControlStructure] = + traversal.controlStructureTypeExact(ControlStructureTypes.THROW) end ControlStructureTraversal diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/IdentifierTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/IdentifierTraversal.scala index bfc40a6c..2c42df8d 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/IdentifierTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/IdentifierTraversal.scala @@ -7,7 +7,7 @@ import io.shiftleft.semanticcpg.language.toTraversalSugarExt */ class IdentifierTraversal(val traversal: Iterator[Identifier]) extends AnyVal: - /** Traverse to all declarations of this identifier - */ - def refsTo: Iterator[Declaration] = - traversal.flatMap(_.refOut).cast[Declaration] + /** Traverse to all declarations of this identifier + */ + def refsTo: Iterator[Declaration] = + traversal.flatMap(_.refOut).cast[Declaration] diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/AstNodeTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/AstNodeTraversal.scala index 1f64a3d5..634ee14a 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/AstNodeTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/AstNodeTraversal.scala @@ -9,217 +9,217 @@ import overflowdb.traversal.help.Doc @help.Traversal(elementType = classOf[AstNode]) class AstNodeTraversal[A <: AstNode](val traversal: Iterator[A]) extends AnyVal: - /** Nodes of the AST rooted in this node, including the node itself. - */ - @Doc(info = "All nodes of the abstract syntax tree") - def ast: Iterator[AstNode] = - traversal.repeat(_.out(EdgeTypes.AST))(_.emit).cast[AstNode] - - /** All nodes of the abstract syntax tree rooted in this node, which match `predicate`. - * Equivalent of `match` in the original CPG paper. - */ - def ast(predicate: AstNode => Boolean): Iterator[AstNode] = - ast.filter(predicate) - - def containsCallTo(regex: String): Iterator[A] = - traversal.filter(_.ast.isCall.name(regex).nonEmpty) - - @Doc(info = "Depth of the abstract syntax tree") - def depth: Iterator[Int] = - traversal.map(_.depth) - - def depth(p: AstNode => Boolean): Iterator[Int] = - traversal.map(_.depth(p)) - - def isCallTo(regex: String): Iterator[Call] = - isCall.name(regex) - - /** Nodes of the AST rooted in this node, minus the node itself - */ - def astMinusRoot: Iterator[AstNode] = - traversal.repeat(_.out(EdgeTypes.AST))(_.emitAllButFirst).cast[AstNode] - - /** Direct children of node in the AST. Siblings are ordered by their `order` fields - */ - def astChildren: Iterator[AstNode] = - traversal.flatMap(_.astChildren).sortBy(_.order).iterator - - /** Parent AST node - */ - def astParent: Iterator[AstNode] = - traversal.in(EdgeTypes.AST).cast[AstNode] - - /** Siblings of this node in the AST, ordered by their `order` fields - */ - def astSiblings: Iterator[AstNode] = - traversal.flatMap(_.astSiblings) - - /** Traverses up the AST and returns the first block node. - */ - def parentBlock: Iterator[Block] = - traversal.repeat(_.in(EdgeTypes.AST))(_.emit.until(_.hasLabel(NodeTypes.BLOCK))).collectAll[ - Block - ] - - /** Nodes of the AST obtained by expanding AST edges backwards until the method root is reached - */ - def inAst: Iterator[AstNode] = - inAst(null) - - /** Nodes of the AST obtained by expanding AST edges backwards until the method root is reached, - * minus this node - */ - def inAstMinusLeaf: Iterator[AstNode] = - inAstMinusLeaf(null) - - /** Nodes of the AST obtained by expanding AST edges backwards until `root` or the method root - * is reached - */ - def inAst(root: AstNode): Iterator[AstNode] = - traversal - .repeat(_.in(EdgeTypes.AST))( - _.emit - .until(_.or( - _.hasLabel(NodeTypes.METHOD), - _.filter(n => root != null && root == n) - )) - ) - .cast[AstNode] - - /** Nodes of the AST obtained by expanding AST edges backwards until `root` or the method root - * is reached, minus this node - */ - def inAstMinusLeaf(root: AstNode): Iterator[AstNode] = - traversal - .repeat(_.in(EdgeTypes.AST))( - _.emitAllButFirst - .until(_.or( - _.hasLabel(NodeTypes.METHOD), - _.filter(n => root != null && root == n) - )) - ) - .cast[AstNode] - - /** Traverse only to those AST nodes that are also control flow graph nodes - */ - def isCfgNode: Iterator[CfgNode] = - traversal.collectAll[CfgNode] - - def isAnnotation: Iterator[Annotation] = - traversal.collectAll[Annotation] - - def isAnnotationLiteral: Iterator[AnnotationLiteral] = - traversal.collectAll[AnnotationLiteral] - - def isArrayInitializer: Iterator[ArrayInitializer] = - traversal.collectAll[ArrayInitializer] - - /** Traverse only to those AST nodes that are blocks - */ - def isBlock: Iterator[Block] = - traversal.collectAll[Block] - - /** Traverse only to those AST nodes that are control structures - */ - def isControlStructure: Iterator[ControlStructure] = - traversal.collectAll[ControlStructure] - - /** Traverse only to AST nodes that are expressions - */ - def isExpression: Iterator[Expression] = - traversal.collectAll[Expression] - - /** Traverse only to AST nodes that are calls - */ - def isCall: Iterator[Call] = - traversal.collectAll[Call] - - /** Cast to call if applicable and filter on call code `calleeRegex` - */ - def isCall(calleeRegex: String): Iterator[Call] = - isCall.where(_.code(calleeRegex)) - - /** Traverse only to AST nodes that are literals - */ - def isLiteral: Iterator[Literal] = - traversal.collectAll[Literal] - - def isLocal: Iterator[Local] = - traversal.collectAll[Local] - - /** Traverse only to AST nodes that are identifier - */ - def isIdentifier: Iterator[Identifier] = - traversal.collectAll[Identifier] - - /** Traverse only to AST nodes that are IMPORT nodes - */ - def isImport: Iterator[Import] = - traversal.collectAll[Import] - - /** Traverse only to FILE AST nodes - */ - def isFile: Iterator[File] = - traversal.collectAll[File] - - /** Traverse only to AST nodes that are field identifier - */ - def isFieldIdentifier: Iterator[FieldIdentifier] = - traversal.collectAll[FieldIdentifier] - - /** Traverse only to AST nodes that are return nodes - */ - def isReturn: Iterator[Return] = - traversal.collectAll[Return] - - /** Traverse only to AST nodes that are MEMBER - */ - def isMember: Iterator[Member] = - traversal.collectAll[Member] - - /** Traverse only to AST nodes that are method reference - */ - def isMethodRef: Iterator[MethodRef] = - traversal.collectAll[MethodRef] - - /** Traverse only to AST nodes that are type reference - */ - def isTypeRef: Iterator[TypeRef] = - traversal.collectAll[TypeRef] - - /** Traverse only to AST nodes that are METHOD - */ - def isMethod: Iterator[Method] = - traversal.collectAll[Method] - - /** Traverse only to AST nodes that are MODIFIER - */ - def isModifier: Iterator[Modifier] = - traversal.collectAll[Modifier] - - /** Traverse only to AST nodes that are NAMESPACE_BLOCK - */ - def isNamespaceBlock: Iterator[NamespaceBlock] = - traversal.collectAll[NamespaceBlock] - - /** Traverse only to AST nodes that are METHOD_PARAMETER_IN - */ - def isParameter: Iterator[MethodParameterIn] = - traversal.collectAll[MethodParameterIn] - - /** Traverse only to AST nodes that are TemplateDom nodes - */ - def isTemplateDom: Iterator[TemplateDom] = - traversal.collectAll[TemplateDom] - - /** Traverse only to AST nodes that are TYPE_DECL - */ - def isTypeDecl: Iterator[TypeDecl] = - traversal.collectAll[TypeDecl] - - def walkAstUntilReaching(labels: List[String]): Iterator[StoredNode] = - traversal - .repeat(_.out(EdgeTypes.AST))(_.emitAllButFirst.until(_.hasLabel(labels*))) - .dedup - .cast[StoredNode] + /** Nodes of the AST rooted in this node, including the node itself. + */ + @Doc(info = "All nodes of the abstract syntax tree") + def ast: Iterator[AstNode] = + traversal.repeat(_.out(EdgeTypes.AST))(_.emit).cast[AstNode] + + /** All nodes of the abstract syntax tree rooted in this node, which match `predicate`. Equivalent + * of `match` in the original CPG paper. + */ + def ast(predicate: AstNode => Boolean): Iterator[AstNode] = + ast.filter(predicate) + + def containsCallTo(regex: String): Iterator[A] = + traversal.filter(_.ast.isCall.name(regex).nonEmpty) + + @Doc(info = "Depth of the abstract syntax tree") + def depth: Iterator[Int] = + traversal.map(_.depth) + + def depth(p: AstNode => Boolean): Iterator[Int] = + traversal.map(_.depth(p)) + + def isCallTo(regex: String): Iterator[Call] = + isCall.name(regex) + + /** Nodes of the AST rooted in this node, minus the node itself + */ + def astMinusRoot: Iterator[AstNode] = + traversal.repeat(_.out(EdgeTypes.AST))(_.emitAllButFirst).cast[AstNode] + + /** Direct children of node in the AST. Siblings are ordered by their `order` fields + */ + def astChildren: Iterator[AstNode] = + traversal.flatMap(_.astChildren).sortBy(_.order).iterator + + /** Parent AST node + */ + def astParent: Iterator[AstNode] = + traversal.in(EdgeTypes.AST).cast[AstNode] + + /** Siblings of this node in the AST, ordered by their `order` fields + */ + def astSiblings: Iterator[AstNode] = + traversal.flatMap(_.astSiblings) + + /** Traverses up the AST and returns the first block node. + */ + def parentBlock: Iterator[Block] = + traversal.repeat(_.in(EdgeTypes.AST))(_.emit.until(_.hasLabel(NodeTypes.BLOCK))).collectAll[ + Block + ] + + /** Nodes of the AST obtained by expanding AST edges backwards until the method root is reached + */ + def inAst: Iterator[AstNode] = + inAst(null) + + /** Nodes of the AST obtained by expanding AST edges backwards until the method root is reached, + * minus this node + */ + def inAstMinusLeaf: Iterator[AstNode] = + inAstMinusLeaf(null) + + /** Nodes of the AST obtained by expanding AST edges backwards until `root` or the method root is + * reached + */ + def inAst(root: AstNode): Iterator[AstNode] = + traversal + .repeat(_.in(EdgeTypes.AST))( + _.emit + .until(_.or( + _.hasLabel(NodeTypes.METHOD), + _.filter(n => root != null && root == n) + )) + ) + .cast[AstNode] + + /** Nodes of the AST obtained by expanding AST edges backwards until `root` or the method root is + * reached, minus this node + */ + def inAstMinusLeaf(root: AstNode): Iterator[AstNode] = + traversal + .repeat(_.in(EdgeTypes.AST))( + _.emitAllButFirst + .until(_.or( + _.hasLabel(NodeTypes.METHOD), + _.filter(n => root != null && root == n) + )) + ) + .cast[AstNode] + + /** Traverse only to those AST nodes that are also control flow graph nodes + */ + def isCfgNode: Iterator[CfgNode] = + traversal.collectAll[CfgNode] + + def isAnnotation: Iterator[Annotation] = + traversal.collectAll[Annotation] + + def isAnnotationLiteral: Iterator[AnnotationLiteral] = + traversal.collectAll[AnnotationLiteral] + + def isArrayInitializer: Iterator[ArrayInitializer] = + traversal.collectAll[ArrayInitializer] + + /** Traverse only to those AST nodes that are blocks + */ + def isBlock: Iterator[Block] = + traversal.collectAll[Block] + + /** Traverse only to those AST nodes that are control structures + */ + def isControlStructure: Iterator[ControlStructure] = + traversal.collectAll[ControlStructure] + + /** Traverse only to AST nodes that are expressions + */ + def isExpression: Iterator[Expression] = + traversal.collectAll[Expression] + + /** Traverse only to AST nodes that are calls + */ + def isCall: Iterator[Call] = + traversal.collectAll[Call] + + /** Cast to call if applicable and filter on call code `calleeRegex` + */ + def isCall(calleeRegex: String): Iterator[Call] = + isCall.where(_.code(calleeRegex)) + + /** Traverse only to AST nodes that are literals + */ + def isLiteral: Iterator[Literal] = + traversal.collectAll[Literal] + + def isLocal: Iterator[Local] = + traversal.collectAll[Local] + + /** Traverse only to AST nodes that are identifier + */ + def isIdentifier: Iterator[Identifier] = + traversal.collectAll[Identifier] + + /** Traverse only to AST nodes that are IMPORT nodes + */ + def isImport: Iterator[Import] = + traversal.collectAll[Import] + + /** Traverse only to FILE AST nodes + */ + def isFile: Iterator[File] = + traversal.collectAll[File] + + /** Traverse only to AST nodes that are field identifier + */ + def isFieldIdentifier: Iterator[FieldIdentifier] = + traversal.collectAll[FieldIdentifier] + + /** Traverse only to AST nodes that are return nodes + */ + def isReturn: Iterator[Return] = + traversal.collectAll[Return] + + /** Traverse only to AST nodes that are MEMBER + */ + def isMember: Iterator[Member] = + traversal.collectAll[Member] + + /** Traverse only to AST nodes that are method reference + */ + def isMethodRef: Iterator[MethodRef] = + traversal.collectAll[MethodRef] + + /** Traverse only to AST nodes that are type reference + */ + def isTypeRef: Iterator[TypeRef] = + traversal.collectAll[TypeRef] + + /** Traverse only to AST nodes that are METHOD + */ + def isMethod: Iterator[Method] = + traversal.collectAll[Method] + + /** Traverse only to AST nodes that are MODIFIER + */ + def isModifier: Iterator[Modifier] = + traversal.collectAll[Modifier] + + /** Traverse only to AST nodes that are NAMESPACE_BLOCK + */ + def isNamespaceBlock: Iterator[NamespaceBlock] = + traversal.collectAll[NamespaceBlock] + + /** Traverse only to AST nodes that are METHOD_PARAMETER_IN + */ + def isParameter: Iterator[MethodParameterIn] = + traversal.collectAll[MethodParameterIn] + + /** Traverse only to AST nodes that are TemplateDom nodes + */ + def isTemplateDom: Iterator[TemplateDom] = + traversal.collectAll[TemplateDom] + + /** Traverse only to AST nodes that are TYPE_DECL + */ + def isTypeDecl: Iterator[TypeDecl] = + traversal.collectAll[TypeDecl] + + def walkAstUntilReaching(labels: List[String]): Iterator[StoredNode] = + traversal + .repeat(_.out(EdgeTypes.AST))(_.emitAllButFirst.until(_.hasLabel(labels*))) + .dedup + .cast[StoredNode] end AstNodeTraversal diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/CfgNodeTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/CfgNodeTraversal.scala index bc64bd8a..4a89ff84 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/CfgNodeTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/CfgNodeTraversal.scala @@ -8,85 +8,85 @@ import overflowdb.traversal.help.Doc @help.Traversal(elementType = classOf[CfgNode]) class CfgNodeTraversal[A <: CfgNode](val traversal: Iterator[A]) extends AnyVal: - /** Textual representation of CFG node - */ - @Doc(info = "Textual representation of CFG node") - def repr: Iterator[String] = - traversal.map(_.repr) - - /** Traverse to enclosing method - */ - def method: Iterator[Method] = - traversal.map( - _.method - ) // refers to `semanticcpg.language.nodemethods.CfgNodeMethods.method` - - /** Traverse to next expression in CFG. - */ - - @Doc(info = "Nodes directly reachable via outgoing CFG edges") - def cfgNext: Iterator[CfgNode] = - traversal._cfgOut - .filterNot(_.isInstanceOf[MethodReturn]) - .cast[CfgNode] - - /** Traverse to previous expression in CFG. - */ - @Doc(info = "Nodes directly reachable via incoming CFG edges") - def cfgPrev: Iterator[CfgNode] = - traversal._cfgIn - .filterNot(_.isInstanceOf[MethodReturn]) - .cast[CfgNode] - - /** All nodes reachable in the CFG by up to n forward expansions - */ - def cfgNext(n: Int): Iterator[CfgNode] = traversal.flatMap(_.cfgNext(n)) - - /** All nodes reachable in the CFG by up to n backward expansions - */ - def cfgPrev(n: Int): Iterator[CfgNode] = traversal.flatMap(_.cfgPrev(n)) - - /** Recursively determine all nodes on which any of the nodes in this traversal are control - * dependent - */ - @Doc(info = "All nodes on which this node is control dependent") - def controlledBy: Iterator[CfgNode] = - traversal.flatMap(_.controlledBy) - - /** Recursively determine all nodes which are control dependent on this node - */ - @Doc(info = "All nodes control dependent on this node") - def controls: Iterator[CfgNode] = - traversal.flatMap(_.controls) - - /** Recursively determine all nodes by which this node is dominated - */ - @Doc(info = "All nodes by which this node is dominated") - def dominatedBy: Iterator[CfgNode] = - traversal.flatMap(_.dominatedBy) - - /** Recursively determine all nodes which this node dominates - */ - @Doc(info = "All nodes that are dominated by this node") - def dominates: Iterator[CfgNode] = - traversal.flatMap(_.dominates) - - /** Recursively determine all nodes by which this node is post dominated - */ - @Doc(info = "All nodes by which this node is post dominated") - def postDominatedBy: Iterator[CfgNode] = - traversal.flatMap(_.postDominatedBy) - - /** Recursively determine all nodes which this node post dominates - */ - @Doc(info = "All nodes that are post dominated by this node") - def postDominates: Iterator[CfgNode] = - traversal.flatMap(_.postDominates) - - /** Obtain hexadecimal string representation of lineNumber field. - */ - @Doc(info = "Address of the code (for binary code)") - def address: Iterator[Option[String]] = - traversal.map(_.address) + /** Textual representation of CFG node + */ + @Doc(info = "Textual representation of CFG node") + def repr: Iterator[String] = + traversal.map(_.repr) + + /** Traverse to enclosing method + */ + def method: Iterator[Method] = + traversal.map( + _.method + ) // refers to `semanticcpg.language.nodemethods.CfgNodeMethods.method` + + /** Traverse to next expression in CFG. + */ + + @Doc(info = "Nodes directly reachable via outgoing CFG edges") + def cfgNext: Iterator[CfgNode] = + traversal._cfgOut + .filterNot(_.isInstanceOf[MethodReturn]) + .cast[CfgNode] + + /** Traverse to previous expression in CFG. + */ + @Doc(info = "Nodes directly reachable via incoming CFG edges") + def cfgPrev: Iterator[CfgNode] = + traversal._cfgIn + .filterNot(_.isInstanceOf[MethodReturn]) + .cast[CfgNode] + + /** All nodes reachable in the CFG by up to n forward expansions + */ + def cfgNext(n: Int): Iterator[CfgNode] = traversal.flatMap(_.cfgNext(n)) + + /** All nodes reachable in the CFG by up to n backward expansions + */ + def cfgPrev(n: Int): Iterator[CfgNode] = traversal.flatMap(_.cfgPrev(n)) + + /** Recursively determine all nodes on which any of the nodes in this traversal are control + * dependent + */ + @Doc(info = "All nodes on which this node is control dependent") + def controlledBy: Iterator[CfgNode] = + traversal.flatMap(_.controlledBy) + + /** Recursively determine all nodes which are control dependent on this node + */ + @Doc(info = "All nodes control dependent on this node") + def controls: Iterator[CfgNode] = + traversal.flatMap(_.controls) + + /** Recursively determine all nodes by which this node is dominated + */ + @Doc(info = "All nodes by which this node is dominated") + def dominatedBy: Iterator[CfgNode] = + traversal.flatMap(_.dominatedBy) + + /** Recursively determine all nodes which this node dominates + */ + @Doc(info = "All nodes that are dominated by this node") + def dominates: Iterator[CfgNode] = + traversal.flatMap(_.dominates) + + /** Recursively determine all nodes by which this node is post dominated + */ + @Doc(info = "All nodes by which this node is post dominated") + def postDominatedBy: Iterator[CfgNode] = + traversal.flatMap(_.postDominatedBy) + + /** Recursively determine all nodes which this node post dominates + */ + @Doc(info = "All nodes that are post dominated by this node") + def postDominates: Iterator[CfgNode] = + traversal.flatMap(_.postDominates) + + /** Obtain hexadecimal string representation of lineNumber field. + */ + @Doc(info = "Address of the code (for binary code)") + def address: Iterator[Option[String]] = + traversal.map(_.address) end CfgNodeTraversal diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/DeclarationTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/DeclarationTraversal.scala index c4018300..12d5c1f1 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/DeclarationTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/DeclarationTraversal.scala @@ -15,17 +15,17 @@ import overflowdb.traversal.help class DeclarationTraversal[NodeType <: Declaration](val traversal: Iterator[NodeType]) extends AnyVal: - /** The closure binding node referenced by this declaration - */ - def closureBinding: Iterator[ClosureBinding] = - traversal.flatMap(_._refIn).collectAll[ClosureBinding] + /** The closure binding node referenced by this declaration + */ + def closureBinding: Iterator[ClosureBinding] = + traversal.flatMap(_._refIn).collectAll[ClosureBinding] - /** Methods that capture this declaration - */ - def capturedByMethodRef: Iterator[MethodRef] = - closureBinding.flatMap(_._captureIn).collectAll[MethodRef] + /** Methods that capture this declaration + */ + def capturedByMethodRef: Iterator[MethodRef] = + closureBinding.flatMap(_._captureIn).collectAll[MethodRef] - /** Types that capture this declaration - */ - def capturedByTypeRef: Iterator[TypeRef] = - closureBinding.flatMap(_._captureIn).collectAll[TypeRef] + /** Types that capture this declaration + */ + def capturedByTypeRef: Iterator[TypeRef] = + closureBinding.flatMap(_._captureIn).collectAll[TypeRef] diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/ExpressionTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/ExpressionTraversal.scala index a297ad14..ed7dc3e4 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/ExpressionTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/expressions/generalizations/ExpressionTraversal.scala @@ -8,67 +8,67 @@ import io.shiftleft.semanticcpg.language.* */ class ExpressionTraversal[NodeType <: Expression](val traversal: Iterator[NodeType]) extends AnyVal: - /** Traverse to it's parent expression (e.g. call or return) by following the incoming AST It's - * continuing it's walk until it hits an expression that's not a generic "member access - * operation", e.g., ".memberAccess". - */ - def parentExpression: Iterator[Expression] = - traversal.flatMap(_.parentExpression) + /** Traverse to it's parent expression (e.g. call or return) by following the incoming AST It's + * continuing it's walk until it hits an expression that's not a generic "member access + * operation", e.g., ".memberAccess". + */ + def parentExpression: Iterator[Expression] = + traversal.flatMap(_.parentExpression) - /** Traverse to enclosing expression - */ - def expressionUp: Iterator[Expression] = - traversal.flatMap(_.expressionUp) + /** Traverse to enclosing expression + */ + def expressionUp: Iterator[Expression] = + traversal.flatMap(_.expressionUp) - /** Traverse to sub expressions - */ - def expressionDown: Iterator[Expression] = - traversal.flatMap(_.expressionDown) + /** Traverse to sub expressions + */ + def expressionDown: Iterator[Expression] = + traversal.flatMap(_.expressionDown) - /** If the expression is used as receiver for a call, this traverses to the call. - */ - def receivedCall: Iterator[Call] = - traversal.flatMap(_.receivedCall) + /** If the expression is used as receiver for a call, this traverses to the call. + */ + def receivedCall: Iterator[Call] = + traversal.flatMap(_.receivedCall) - /** Only those expressions which are (direct) arguments of a call - */ - def isArgument: Iterator[Expression] = - traversal.flatMap(_.isArgument) + /** Only those expressions which are (direct) arguments of a call + */ + def isArgument: Iterator[Expression] = + traversal.flatMap(_.isArgument) - /** Traverse to surrounding call - */ - def inCall: Iterator[Call] = - traversal.flatMap(_.inCall) + /** Traverse to surrounding call + */ + def inCall: Iterator[Call] = + traversal.flatMap(_.inCall) - /** Traverse to surrounding call - */ - @deprecated("Use inCall") - def call: Iterator[Call] = - inCall + /** Traverse to surrounding call + */ + @deprecated("Use inCall") + def call: Iterator[Call] = + inCall - /** Traverse to related parameter - */ - @deprecated("", "October 2019") - def toParameter(implicit callResolver: ICallResolver): Iterator[MethodParameterIn] = parameter + /** Traverse to related parameter + */ + @deprecated("", "October 2019") + def toParameter(implicit callResolver: ICallResolver): Iterator[MethodParameterIn] = parameter - /** Traverse to related parameter, if the expression is an argument to a call and the call can - * be resolved. - */ - def parameter(implicit callResolver: ICallResolver): Iterator[MethodParameterIn] = - traversal.flatMap(_.parameter) + /** Traverse to related parameter, if the expression is an argument to a call and the call can be + * resolved. + */ + def parameter(implicit callResolver: ICallResolver): Iterator[MethodParameterIn] = + traversal.flatMap(_.parameter) - /** Traverse to enclosing method - */ - def method: Iterator[Method] = - traversal._containsIn - .flatMap { - case x: Method => x.start - case x: TypeDecl => x.astParent - } - .collectAll[Method] + /** Traverse to enclosing method + */ + def method: Iterator[Method] = + traversal._containsIn + .flatMap { + case x: Method => x.start + case x: TypeDecl => x.astParent + } + .collectAll[Method] - /** Traverse to expression evaluation type - */ - def typ: Iterator[Type] = - traversal.flatMap(_._evalTypeOut).cast[Type] + /** Traverse to expression evaluation type + */ + def typ: Iterator[Type] = + traversal.flatMap(_._evalTypeOut).cast[Type] end ExpressionTraversal diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/propertyaccessors/EvalTypeAccessors.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/propertyaccessors/EvalTypeAccessors.scala index c3133f55..83f5610e 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/propertyaccessors/EvalTypeAccessors.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/propertyaccessors/EvalTypeAccessors.scala @@ -6,36 +6,36 @@ import io.shiftleft.semanticcpg.language.* class EvalTypeAccessors[A <: AstNode](val traversal: Iterator[A]) extends AnyVal: - def evalType: Iterator[String] = - evalType(traversal) - - def evalType(regex: String): Iterator[A] = - traversal.where(evalType(_).filter(_.matches(regex))) - - def evalType(regexes: String*): Iterator[A] = - if regexes.isEmpty then Iterator.empty - else - val regexes0 = regexes.map(_.r).toSet - traversal.where(evalType(_).filter(value => regexes0.exists(_.matches(value)))) - - def evalTypeExact(value: String): Iterator[A] = - traversal.where(evalType(_).filter(_ == value)) - - def evalTypeExact(values: String*): Iterator[A] = - if values.isEmpty then Iterator.empty - else - val valuesSet = values.to(Set) - traversal.where(evalType(_).filter(valuesSet.contains)) - - def evalTypeNot(regex: String): Iterator[A] = - traversal.where(evalType(_).filterNot(_.matches(regex))) - - def evalTypeNot(regexes: String*): Iterator[A] = - if regexes.isEmpty then Iterator.empty - else - val regexes0 = regexes.map(_.r).toSet - traversal.where(evalType(_).filter(value => !regexes0.exists(_.matches(value)))) - - private def evalType(traversal: Iterator[A]): Iterator[String] = - traversal.flatMap(_._evalTypeOut).flatMap(_._refOut).property(Properties.FULL_NAME) + def evalType: Iterator[String] = + evalType(traversal) + + def evalType(regex: String): Iterator[A] = + traversal.where(evalType(_).filter(_.matches(regex))) + + def evalType(regexes: String*): Iterator[A] = + if regexes.isEmpty then Iterator.empty + else + val regexes0 = regexes.map(_.r).toSet + traversal.where(evalType(_).filter(value => regexes0.exists(_.matches(value)))) + + def evalTypeExact(value: String): Iterator[A] = + traversal.where(evalType(_).filter(_ == value)) + + def evalTypeExact(values: String*): Iterator[A] = + if values.isEmpty then Iterator.empty + else + val valuesSet = values.to(Set) + traversal.where(evalType(_).filter(valuesSet.contains)) + + def evalTypeNot(regex: String): Iterator[A] = + traversal.where(evalType(_).filterNot(_.matches(regex))) + + def evalTypeNot(regexes: String*): Iterator[A] = + if regexes.isEmpty then Iterator.empty + else + val regexes0 = regexes.map(_.r).toSet + traversal.where(evalType(_).filter(value => !regexes0.exists(_.matches(value)))) + + private def evalType(traversal: Iterator[A]): Iterator[String] = + traversal.flatMap(_._evalTypeOut).flatMap(_._refOut).property(Properties.FULL_NAME) end EvalTypeAccessors diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/propertyaccessors/ModifierAccessors.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/propertyaccessors/ModifierAccessors.scala index fa3d378a..7ced869e 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/propertyaccessors/ModifierAccessors.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/propertyaccessors/ModifierAccessors.scala @@ -8,43 +8,43 @@ import overflowdb.* class ModifierAccessors[A <: Node](val traversal: Iterator[A]) extends AnyVal: - /** Filter: only `public` nodes */ - def isPublic: Iterator[A] = - hasModifier(ModifierTypes.PUBLIC) + /** Filter: only `public` nodes */ + def isPublic: Iterator[A] = + hasModifier(ModifierTypes.PUBLIC) - /** Filter: only `private` nodes */ - def isPrivate: Iterator[A] = - hasModifier(ModifierTypes.PRIVATE) + /** Filter: only `private` nodes */ + def isPrivate: Iterator[A] = + hasModifier(ModifierTypes.PRIVATE) - /** Filter: only `protected` nodes */ - def isProtected: Iterator[A] = - hasModifier(ModifierTypes.PROTECTED) + /** Filter: only `protected` nodes */ + def isProtected: Iterator[A] = + hasModifier(ModifierTypes.PROTECTED) - /** Filter: only `abstract` nodes */ - def isAbstract: Iterator[A] = - hasModifier(ModifierTypes.ABSTRACT) + /** Filter: only `abstract` nodes */ + def isAbstract: Iterator[A] = + hasModifier(ModifierTypes.ABSTRACT) - /** Filter: only `static` nodes */ - def isStatic: Iterator[A] = - hasModifier(ModifierTypes.STATIC) + /** Filter: only `static` nodes */ + def isStatic: Iterator[A] = + hasModifier(ModifierTypes.STATIC) - /** Filter: only `native` nodes */ - def isNative: Iterator[A] = - hasModifier(ModifierTypes.NATIVE) + /** Filter: only `native` nodes */ + def isNative: Iterator[A] = + hasModifier(ModifierTypes.NATIVE) - /** Filter: only `constructor` nodes */ - def isConstructor: Iterator[A] = - hasModifier(ModifierTypes.CONSTRUCTOR) + /** Filter: only `constructor` nodes */ + def isConstructor: Iterator[A] = + hasModifier(ModifierTypes.CONSTRUCTOR) - /** Filter: only `virtual` nodes */ - def isVirtual: Iterator[A] = - hasModifier(ModifierTypes.VIRTUAL) + /** Filter: only `virtual` nodes */ + def isVirtual: Iterator[A] = + hasModifier(ModifierTypes.VIRTUAL) - def hasModifier(modifier: String): Iterator[A] = - traversal.where(_.out.collectAll[Modifier].modifierType(modifier)) + def hasModifier(modifier: String): Iterator[A] = + traversal.where(_.out.collectAll[Modifier].modifierType(modifier)) - /** Traverse to modifiers, e.g., "static", "public". - */ - def modifier: Iterator[Modifier] = - traversal.out.collectAll[Modifier] + /** Traverse to modifiers, e.g., "static", "public". + */ + def modifier: Iterator[Modifier] = + traversal.out.collectAll[Modifier] end ModifierAccessors diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/AnnotationParameterAssignTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/AnnotationParameterAssignTraversal.scala index 00dd5e87..27e3d190 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/AnnotationParameterAssignTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/AnnotationParameterAssignTraversal.scala @@ -8,15 +8,15 @@ import io.shiftleft.semanticcpg.language.* class AnnotationParameterAssignTraversal(val traversal: Iterator[AnnotationParameterAssign]) extends AnyVal: - /** Traverse to all annotation parameters - */ - def parameter: Iterator[AnnotationParameter] = - traversal.flatMap(_._annotationParameterViaAstOut) + /** Traverse to all annotation parameters + */ + def parameter: Iterator[AnnotationParameter] = + traversal.flatMap(_._annotationParameterViaAstOut) - /** Traverse to all values of annotation parameters - */ - def value: Iterator[Expression] = - traversal - .flatMap(_.astOut) - .filterNot(_.isInstanceOf[AnnotationParameter]) - .cast[Expression] + /** Traverse to all values of annotation parameters + */ + def value: Iterator[Expression] = + traversal + .flatMap(_.astOut) + .filterNot(_.isInstanceOf[AnnotationParameter]) + .cast[Expression] diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/AnnotationTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/AnnotationTraversal.scala index 3df4f8c0..85281f32 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/AnnotationTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/AnnotationTraversal.scala @@ -7,28 +7,28 @@ import overflowdb.traversal.* */ class AnnotationTraversal(val traversal: Iterator[nodes.Annotation]) extends AnyVal: - /** Traverse to parameter assignments - */ - def parameterAssign: Iterator[nodes.AnnotationParameterAssign] = - traversal.flatMap(_._annotationParameterAssignViaAstOut) + /** Traverse to parameter assignments + */ + def parameterAssign: Iterator[nodes.AnnotationParameterAssign] = + traversal.flatMap(_._annotationParameterAssignViaAstOut) - /** Traverse to methods annotated with this annotation. - */ - def method: Iterator[nodes.Method] = - traversal.flatMap(_._methodViaAstIn) + /** Traverse to methods annotated with this annotation. + */ + def method: Iterator[nodes.Method] = + traversal.flatMap(_._methodViaAstIn) - /** Traverse to type declarations annotated by this annotation - */ - def typeDecl: Iterator[nodes.TypeDecl] = - traversal.flatMap(_._typeDeclViaAstIn) + /** Traverse to type declarations annotated by this annotation + */ + def typeDecl: Iterator[nodes.TypeDecl] = + traversal.flatMap(_._typeDeclViaAstIn) - /** Traverse to member annotated by this annotation - */ - def member: Iterator[nodes.Member] = - traversal.flatMap(_._memberViaAstIn) + /** Traverse to member annotated by this annotation + */ + def member: Iterator[nodes.Member] = + traversal.flatMap(_._memberViaAstIn) - /** Traverse to parameter annotated by this annotation - */ - def parameter: Iterator[nodes.MethodParameterIn] = - traversal.flatMap(_._methodParameterInViaAstIn) + /** Traverse to parameter annotated by this annotation + */ + def parameter: Iterator[nodes.MethodParameterIn] = + traversal.flatMap(_._methodParameterInViaAstIn) end AnnotationTraversal diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/DependencyTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/DependencyTraversal.scala index 82339ab6..6cc5df83 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/DependencyTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/DependencyTraversal.scala @@ -5,4 +5,4 @@ import io.shiftleft.codepropertygraph.generated.{EdgeTypes, nodes} import io.shiftleft.semanticcpg.language.* class DependencyTraversal(val traversal: Iterator[nodes.Dependency]) extends AnyVal: - def imports: Iterator[Import] = traversal.in(EdgeTypes.IMPORTS).cast[Import] + def imports: Iterator[Import] = traversal.in(EdgeTypes.IMPORTS).cast[Import] diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/FileTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/FileTraversal.scala index 99345de9..79424c89 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/FileTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/FileTraversal.scala @@ -7,8 +7,8 @@ import io.shiftleft.semanticcpg.language.* */ class FileTraversal(val traversal: Iterator[File]) extends AnyVal: - def namespace: Iterator[Namespace] = - traversal.flatMap(_.namespaceBlock).flatMap(_._namespaceViaRefOut) + def namespace: Iterator[Namespace] = + traversal.flatMap(_.namespaceBlock).flatMap(_._namespaceViaRefOut) object FileTraversal: - val UNKNOWN = "" + val UNKNOWN = "" diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/ImportTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/ImportTraversal.scala index 116e104d..74008b5a 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/ImportTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/ImportTraversal.scala @@ -5,9 +5,9 @@ import io.shiftleft.semanticcpg.language.* class ImportTraversal(val traversal: Iterator[Import]) extends AnyVal: - /** Traverse to the call that represents the import in the AST - */ - def call: Iterator[Call] = traversal._isCallForImportIn.cast[Call] + /** Traverse to the call that represents the import in the AST + */ + def call: Iterator[Call] = traversal._isCallForImportIn.cast[Call] - def namespaceBlock: Iterator[NamespaceBlock] = - call.method.namespaceBlock + def namespaceBlock: Iterator[NamespaceBlock] = + call.method.namespaceBlock diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/LocalTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/LocalTraversal.scala index b7cde45c..b51d0020 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/LocalTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/LocalTraversal.scala @@ -8,10 +8,10 @@ import io.shiftleft.semanticcpg.language.* */ class LocalTraversal(val traversal: Iterator[Local]) extends AnyVal: - /** The method hosting this local variable - */ - def method: Iterator[Method] = - // TODO The following line of code is here for backwards compatibility. - // Use the lower commented out line once not required anymore. - traversal.repeat(_.in(EdgeTypes.AST))(_.until(_.hasLabel(NodeTypes.METHOD))).cast[Method] - // definingBlock.method + /** The method hosting this local variable + */ + def method: Iterator[Method] = + // TODO The following line of code is here for backwards compatibility. + // Use the lower commented out line once not required anymore. + traversal.repeat(_.in(EdgeTypes.AST))(_.until(_.hasLabel(NodeTypes.METHOD))).cast[Method] + // definingBlock.method diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MemberTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MemberTraversal.scala index e63310aa..fe324237 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MemberTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MemberTraversal.scala @@ -8,12 +8,12 @@ import io.shiftleft.semanticcpg.language.* */ class MemberTraversal(val traversal: Iterator[Member]) extends AnyVal: - /** Traverse to annotations of member - */ - def annotation: Iterator[nodes.Annotation] = - traversal.flatMap(_._annotationViaAstOut) + /** Traverse to annotations of member + */ + def annotation: Iterator[nodes.Annotation] = + traversal.flatMap(_._annotationViaAstOut) - /** Places where - */ - def ref: Iterator[Call] = - traversal.flatMap(_._callViaRefIn) + /** Places where + */ + def ref: Iterator[Call] = + traversal.flatMap(_._callViaRefIn) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodParameterOutTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodParameterOutTraversal.scala index ea21c91a..4c4aa188 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodParameterOutTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodParameterOutTraversal.scala @@ -7,28 +7,28 @@ import scala.jdk.CollectionConverters.* class MethodParameterOutTraversal(val traversal: Iterator[MethodParameterOut]) extends AnyVal: - def paramIn: Iterator[MethodParameterIn] = traversal.flatMap(_.parameterLinkIn.headOption) - - /* method parameter indexes are based, i.e. first parameter has index (that's how java2cpg generates it) */ - def index(num: Int): Iterator[MethodParameterOut] = - traversal.filter { _.index == num } - - /* get all parameters from (and including) - * method parameter indexes are based, i.e. first parameter has index (that's how java2cpg generates it) */ - def indexFrom(num: Int): Iterator[MethodParameterOut] = - traversal.filter(_.index >= num) - - /* get all parameters up to (and including) - * method parameter indexes are based, i.e. first parameter has index (that's how java2cpg generates it) */ - def indexTo(num: Int): Iterator[MethodParameterOut] = - traversal.filter(_.index <= num) - - def argument: Iterator[Expression] = - for - paramOut <- traversal - method = paramOut.method - call <- method.callIn - arg <- call.argumentOut.collectAll[Expression] - if paramOut.parameterLinkIn.index.headOption.contains(arg.argumentIndex) - yield arg + def paramIn: Iterator[MethodParameterIn] = traversal.flatMap(_.parameterLinkIn.headOption) + + /* method parameter indexes are based, i.e. first parameter has index (that's how java2cpg generates it) */ + def index(num: Int): Iterator[MethodParameterOut] = + traversal.filter { _.index == num } + + /* get all parameters from (and including) + * method parameter indexes are based, i.e. first parameter has index (that's how java2cpg generates it) */ + def indexFrom(num: Int): Iterator[MethodParameterOut] = + traversal.filter(_.index >= num) + + /* get all parameters up to (and including) + * method parameter indexes are based, i.e. first parameter has index (that's how java2cpg generates it) */ + def indexTo(num: Int): Iterator[MethodParameterOut] = + traversal.filter(_.index <= num) + + def argument: Iterator[Expression] = + for + paramOut <- traversal + method = paramOut.method + call <- method.callIn + arg <- call.argumentOut.collectAll[Expression] + if paramOut.parameterLinkIn.index.headOption.contains(arg.argumentIndex) + yield arg end MethodParameterOutTraversal diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodParameterTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodParameterTraversal.scala index e1a90d96..de1ce091 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodParameterTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodParameterTraversal.scala @@ -11,28 +11,28 @@ import scala.jdk.CollectionConverters.* @help.Traversal(elementType = classOf[MethodParameterIn]) class MethodParameterTraversal(val traversal: Iterator[MethodParameterIn]) extends AnyVal: - /** Traverse to parameter annotations - */ - def annotation: Iterator[Annotation] = - traversal.flatMap(_._annotationViaAstOut) + /** Traverse to parameter annotations + */ + def annotation: Iterator[Annotation] = + traversal.flatMap(_._annotationViaAstOut) - /** Traverse to all parameters with index greater or equal than `num` - */ - def indexFrom(num: Int): Iterator[MethodParameterIn] = - traversal.filter(_.index >= num) + /** Traverse to all parameters with index greater or equal than `num` + */ + def indexFrom(num: Int): Iterator[MethodParameterIn] = + traversal.filter(_.index >= num) - /** Traverse to all parameters with index smaller or equal than `num` - */ - def indexTo(num: Int): Iterator[MethodParameterIn] = - traversal.filter(_.index <= num) + /** Traverse to all parameters with index smaller or equal than `num` + */ + def indexTo(num: Int): Iterator[MethodParameterIn] = + traversal.filter(_.index <= num) - /** Traverse to arguments (actual parameters) associated with this formal parameter - */ - def argument(implicit callResolver: ICallResolver): Iterator[Expression] = - for - paramIn <- traversal - call <- callResolver.getMethodCallsites(paramIn.method) - arg <- call._argumentOut.collectAll[Expression] - if arg.argumentIndex == paramIn.index - yield arg + /** Traverse to arguments (actual parameters) associated with this formal parameter + */ + def argument(implicit callResolver: ICallResolver): Iterator[Expression] = + for + paramIn <- traversal + call <- callResolver.getMethodCallsites(paramIn.method) + arg <- call._argumentOut.collectAll[Expression] + if arg.argumentIndex == paramIn.index + yield arg end MethodParameterTraversal diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodReturnTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodReturnTraversal.scala index 90f1c4e5..70b0127b 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodReturnTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodReturnTraversal.scala @@ -8,22 +8,22 @@ import overflowdb.traversal.help.Doc @help.Traversal(elementType = classOf[MethodReturn]) class MethodReturnTraversal(val traversal: Iterator[MethodReturn]) extends AnyVal: - @Doc(info = "traverse to parent method") - def method: Iterator[Method] = - traversal.flatMap(_._methodViaAstIn) + @Doc(info = "traverse to parent method") + def method: Iterator[Method] = + traversal.flatMap(_._methodViaAstIn) - def returnUser(implicit callResolver: ICallResolver): Iterator[Call] = - traversal.flatMap(_.returnUser) + def returnUser(implicit callResolver: ICallResolver): Iterator[Call] = + traversal.flatMap(_.returnUser) - /** Traverse to last expressions in CFG. Can be multiple. - */ - @Doc(info = "traverse to last expressions in CFG (can be multiple)") - def cfgLast: Iterator[CfgNode] = - traversal.flatMap(_.cfgIn) + /** Traverse to last expressions in CFG. Can be multiple. + */ + @Doc(info = "traverse to last expressions in CFG (can be multiple)") + def cfgLast: Iterator[CfgNode] = + traversal.flatMap(_.cfgIn) - /** Traverse to return type - */ - @Doc(info = "traverse to return type") - def typ: Iterator[Type] = - traversal.flatMap(_.evalTypeOut) + /** Traverse to return type + */ + @Doc(info = "traverse to return type") + def typ: Iterator[Type] = + traversal.flatMap(_.evalTypeOut) end MethodReturnTraversal diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodTraversal.scala index 66446311..e946f683 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/MethodTraversal.scala @@ -21,12 +21,12 @@ case class MethodSubGraph( filename: String, nodes: Set[Node] ): - def edges: Set[Edge] = - for - node <- nodes - edge <- node.bothE.asScala - if nodes.contains(edge.inNode) && nodes.contains(edge.outNode) - yield edge + def edges: Set[Edge] = + for + node <- nodes + edge <- node.bothE.asScala + if nodes.contains(edge.inNode) && nodes.contains(edge.outNode) + yield edge def plus(resultA: ExportResult, resultB: ExportResult): ExportResult = ExportResult( @@ -41,254 +41,254 @@ def plus(resultA: ExportResult, resultB: ExportResult): ExportResult = @help.Traversal(elementType = classOf[Method]) class MethodTraversal(val traversal: Iterator[Method]) extends AnyVal: - /** Traverse to annotations of method - */ - def annotation: Iterator[nodes.Annotation] = - traversal.flatMap(_._annotationViaAstOut) - - /** All control structures of this method - */ - @Doc(info = "Control structures (source frontends only)") - def controlStructure: Iterator[ControlStructure] = - traversal.ast.isControlStructure - - /** Shorthand to traverse to control structures where condition matches `regex` - */ - def controlStructure(regex: String): Iterator[ControlStructure] = - traversal.ast.isControlStructure.code(regex) - - @Doc(info = "All try blocks (`ControlStructure` nodes)") - def tryBlock: Iterator[ControlStructure] = - controlStructure.isTry - - @Doc(info = "All if blocks (`ControlStructure` nodes)") - def ifBlock: Iterator[ControlStructure] = - controlStructure.isIf - - @Doc(info = "All else blocks (`ControlStructure` nodes)") - def elseBlock: Iterator[ControlStructure] = - controlStructure.isElse - - @Doc(info = "All switch blocks (`ControlStructure` nodes)") - def switchBlock: Iterator[ControlStructure] = - controlStructure.isSwitch - - @Doc(info = "All do blocks (`ControlStructure` nodes)") - def doBlock: Iterator[ControlStructure] = - controlStructure.isDo - - @Doc(info = "All for blocks (`ControlStructure` nodes)") - def forBlock: Iterator[ControlStructure] = - controlStructure.isFor - - @Doc(info = "All while blocks (`ControlStructure` nodes)") - def whileBlock: Iterator[ControlStructure] = - controlStructure.isWhile - - @Doc(info = "All gotos (`ControlStructure` nodes)") - def goto: Iterator[ControlStructure] = - controlStructure.isGoto - - @Doc(info = "All breaks (`ControlStructure` nodes)") - def break: Iterator[ControlStructure] = - controlStructure.isBreak - - @Doc(info = "All continues (`ControlStructure` nodes)") - def continue: Iterator[ControlStructure] = - controlStructure.isContinue - - @Doc(info = "All throws (`ControlStructure` nodes)") - def throws: Iterator[ControlStructure] = - controlStructure.isThrow - - /** The type declaration associated with this method, e.g., the class it is defined in. - */ - @Doc(info = "Type this method is defined in") - def definingTypeDecl: Iterator[TypeDecl] = - traversal - .repeat(_._astIn)(_.until(_.collectAll[TypeDecl])) - .cast[TypeDecl] - - /** The type declaration associated with this method, e.g., the class it is defined in. Alias - * for 'definingTypeDecl' - */ - @Doc(info = "Type this method is defined in - alias for 'definingTypeDecl'") - def typeDecl: Iterator[TypeDecl] = definingTypeDecl - - /** The method in which this method is defined - */ - @Doc(info = "Method this method is defined in") - def definingMethod: Iterator[Method] = - traversal - .repeat(_._astIn)(_.until(_.collectAll[Method])) - .cast[Method] - - /** Traverse only to methods that are stubs, e.g., their code is not available or the method - * body is empty. - */ - def isStub: Iterator[Method] = - traversal.where(_.not(_._cfgOut.not(_.collectAll[MethodReturn]))) - - /** Traverse only to methods that are not stubs. - */ - def isNotStub: Iterator[Method] = - traversal.where(_._cfgOut.not(_.collectAll[MethodReturn])) - - /** Traverse only to methods that accept variadic arguments. - */ - def isVariadic: Iterator[Method] = - traversal.filter(_.isVariadic) - - /** Traverse to external methods, that is, methods not present but only referenced in the CPG. - */ - @Doc(info = "External methods (called, but no body available)") - def external: Iterator[Method] = - traversal.isExternal(true) - - /** Traverse to internal methods, that is, methods for which code is included in this CPG. - */ - @Doc(info = "Internal methods, i.e., a body is available") - def internal: Iterator[Method] = - traversal.isExternal(false) - - /** Traverse to the methods local variables - */ - @Doc(info = "Local variables declared in the method") - def local: Iterator[Local] = - traversal.block.ast.isLocal - - @Doc(info = "Top level expressions (\"Statements\")") - def topLevelExpressions: Iterator[Expression] = - traversal._astOut - .collectAll[Block] - ._astOut - .not(_.collectAll[Local]) - .cast[Expression] - - @Doc(info = "Control flow graph nodes") - def cfgNode: Iterator[CfgNode] = - traversal.flatMap(_.cfgNode) - - /** Traverse to last expression in CFG. - */ - @Doc(info = "Last control flow graph node") - def cfgLast: Iterator[CfgNode] = - traversal.methodReturn.cfgLast - - /** Traverse to method body (alias for `block`) */ - @Doc(info = "Alias for `block`") - def body: Iterator[Block] = - traversal.block - - /** Traverse to namespace */ - @Doc(info = "Namespace this method is declared in") - def namespace: Iterator[Namespace] = - traversal.namespaceBlock.namespace - - /** Traverse to namespace block */ - @Doc(info = "Namespace block this method is declared in") - def namespaceBlock: Iterator[NamespaceBlock] = - traversal.flatMap { m => - m.astIn.headOption match - // some language frontends don't have a TYPE_DECL for a METHOD - case Some(namespaceBlock: NamespaceBlock) => namespaceBlock.start - // other language frontends always embed their method in a TYPE_DECL - case _ => m.definingTypeDecl.namespaceBlock + /** Traverse to annotations of method + */ + def annotation: Iterator[nodes.Annotation] = + traversal.flatMap(_._annotationViaAstOut) + + /** All control structures of this method + */ + @Doc(info = "Control structures (source frontends only)") + def controlStructure: Iterator[ControlStructure] = + traversal.ast.isControlStructure + + /** Shorthand to traverse to control structures where condition matches `regex` + */ + def controlStructure(regex: String): Iterator[ControlStructure] = + traversal.ast.isControlStructure.code(regex) + + @Doc(info = "All try blocks (`ControlStructure` nodes)") + def tryBlock: Iterator[ControlStructure] = + controlStructure.isTry + + @Doc(info = "All if blocks (`ControlStructure` nodes)") + def ifBlock: Iterator[ControlStructure] = + controlStructure.isIf + + @Doc(info = "All else blocks (`ControlStructure` nodes)") + def elseBlock: Iterator[ControlStructure] = + controlStructure.isElse + + @Doc(info = "All switch blocks (`ControlStructure` nodes)") + def switchBlock: Iterator[ControlStructure] = + controlStructure.isSwitch + + @Doc(info = "All do blocks (`ControlStructure` nodes)") + def doBlock: Iterator[ControlStructure] = + controlStructure.isDo + + @Doc(info = "All for blocks (`ControlStructure` nodes)") + def forBlock: Iterator[ControlStructure] = + controlStructure.isFor + + @Doc(info = "All while blocks (`ControlStructure` nodes)") + def whileBlock: Iterator[ControlStructure] = + controlStructure.isWhile + + @Doc(info = "All gotos (`ControlStructure` nodes)") + def goto: Iterator[ControlStructure] = + controlStructure.isGoto + + @Doc(info = "All breaks (`ControlStructure` nodes)") + def break: Iterator[ControlStructure] = + controlStructure.isBreak + + @Doc(info = "All continues (`ControlStructure` nodes)") + def continue: Iterator[ControlStructure] = + controlStructure.isContinue + + @Doc(info = "All throws (`ControlStructure` nodes)") + def throws: Iterator[ControlStructure] = + controlStructure.isThrow + + /** The type declaration associated with this method, e.g., the class it is defined in. + */ + @Doc(info = "Type this method is defined in") + def definingTypeDecl: Iterator[TypeDecl] = + traversal + .repeat(_._astIn)(_.until(_.collectAll[TypeDecl])) + .cast[TypeDecl] + + /** The type declaration associated with this method, e.g., the class it is defined in. Alias for + * 'definingTypeDecl' + */ + @Doc(info = "Type this method is defined in - alias for 'definingTypeDecl'") + def typeDecl: Iterator[TypeDecl] = definingTypeDecl + + /** The method in which this method is defined + */ + @Doc(info = "Method this method is defined in") + def definingMethod: Iterator[Method] = + traversal + .repeat(_._astIn)(_.until(_.collectAll[Method])) + .cast[Method] + + /** Traverse only to methods that are stubs, e.g., their code is not available or the method body + * is empty. + */ + def isStub: Iterator[Method] = + traversal.where(_.not(_._cfgOut.not(_.collectAll[MethodReturn]))) + + /** Traverse only to methods that are not stubs. + */ + def isNotStub: Iterator[Method] = + traversal.where(_._cfgOut.not(_.collectAll[MethodReturn])) + + /** Traverse only to methods that accept variadic arguments. + */ + def isVariadic: Iterator[Method] = + traversal.filter(_.isVariadic) + + /** Traverse to external methods, that is, methods not present but only referenced in the CPG. + */ + @Doc(info = "External methods (called, but no body available)") + def external: Iterator[Method] = + traversal.isExternal(true) + + /** Traverse to internal methods, that is, methods for which code is included in this CPG. + */ + @Doc(info = "Internal methods, i.e., a body is available") + def internal: Iterator[Method] = + traversal.isExternal(false) + + /** Traverse to the methods local variables + */ + @Doc(info = "Local variables declared in the method") + def local: Iterator[Local] = + traversal.block.ast.isLocal + + @Doc(info = "Top level expressions (\"Statements\")") + def topLevelExpressions: Iterator[Expression] = + traversal._astOut + .collectAll[Block] + ._astOut + .not(_.collectAll[Local]) + .cast[Expression] + + @Doc(info = "Control flow graph nodes") + def cfgNode: Iterator[CfgNode] = + traversal.flatMap(_.cfgNode) + + /** Traverse to last expression in CFG. + */ + @Doc(info = "Last control flow graph node") + def cfgLast: Iterator[CfgNode] = + traversal.methodReturn.cfgLast + + /** Traverse to method body (alias for `block`) */ + @Doc(info = "Alias for `block`") + def body: Iterator[Block] = + traversal.block + + /** Traverse to namespace */ + @Doc(info = "Namespace this method is declared in") + def namespace: Iterator[Namespace] = + traversal.namespaceBlock.namespace + + /** Traverse to namespace block */ + @Doc(info = "Namespace block this method is declared in") + def namespaceBlock: Iterator[NamespaceBlock] = + traversal.flatMap { m => + m.astIn.headOption match + // some language frontends don't have a TYPE_DECL for a METHOD + case Some(namespaceBlock: NamespaceBlock) => namespaceBlock.start + // other language frontends always embed their method in a TYPE_DECL + case _ => m.definingTypeDecl.namespaceBlock + } + + def numberOfLines: Iterator[Int] = traversal.map(_.numberOfLines) + + def sanitizeFilename(filename: String) = + Paths.get(filename).getFileName.toString.replaceAll("[^a-zA-Z0-9-_\\.]", "_") + + def getOrCreateExportPath(pathToUse: String): String = + try + if pathToUse == null then + Files.createTempDirectory("graph-export").toAbsolutePath.toString + else + Paths.get(pathToUse).toFile.mkdirs() + pathToUse + catch + case exc: Exception => pathToUse + + @Doc(info = "Export the methods to graphml") + def gml(gmlDir: String = null): ExportResult = + var pathToUse = getOrCreateExportPath(gmlDir) + traversal + .map { method => + MethodSubGraph( + methodName = method.name, + methodFullName = method.fullName, + filename = method.location.filename, + nodes = method.ast.toSet + ) } + .map { case subGraph @ MethodSubGraph(methodName, methodFullName, filename, nodes) => + val methodHash = Fingerprinting.calculate_hash(methodFullName) + GraphMLExporter.runExport( + nodes, + subGraph.edges, + Paths.get( + pathToUse, + s"${methodName}-${sanitizeFilename(filename)}-${methodHash.slice(0, 8)}.graphml" + ) + ) + } + .reduce(plus) + end gml + + def gml: ExportResult = gml(null) + + def dot(dotDir: String = null): ExportResult = + var pathToUse = getOrCreateExportPath(dotDir) + traversal + .map { method => + MethodSubGraph( + methodName = method.name, + methodFullName = method.fullName, + filename = method.location.filename, + nodes = method.ast.toSet + ) + } + .map { case subGraph @ MethodSubGraph(methodName, methodFullName, filename, nodes) => + val methodHash = Fingerprinting.calculate_hash(methodFullName) + DotExporter.runExport( + nodes, + subGraph.edges, + Paths.get( + pathToUse, + s"${methodName}-${sanitizeFilename(filename)}-${methodHash.slice(0, 8)}.dot" + ) + ) + } + .reduce(plus) + end dot + + def dot: ExportResult = dot(null) + + def exportAllRepr(dotCfgDir: String = null): Unit = + var pathToUse = getOrCreateExportPath(dotCfgDir) + traversal + .foreach { method => + val methodName = method.name + val methodFullName = method.fullName + val filename = method.location.filename + val methodHash = Fingerprinting.calculate_hash(methodFullName) + File( + pathToUse, + s"${methodName}-${sanitizeFilename(filename)}-${methodHash.slice(0, 8)}-ast.dot" + ).write(method.dotAst.head) + File( + pathToUse, + s"${methodName}-${sanitizeFilename(filename)}-${methodHash.slice(0, 8)}-cdg.dot" + ).write(method.dotCdg.head) + File( + pathToUse, + s"${methodName}-${sanitizeFilename(filename)}-${methodHash.slice(0, 8)}-cfg.dot" + ).write(method.dotCfg.head) + } + end exportAllRepr - def numberOfLines: Iterator[Int] = traversal.map(_.numberOfLines) - - def sanitizeFilename(filename: String) = - Paths.get(filename).getFileName.toString.replaceAll("[^a-zA-Z0-9-_\\.]", "_") - - def getOrCreateExportPath(pathToUse: String): String = - try - if pathToUse == null then - Files.createTempDirectory("graph-export").toAbsolutePath.toString - else - Paths.get(pathToUse).toFile.mkdirs() - pathToUse - catch - case exc: Exception => pathToUse - - @Doc(info = "Export the methods to graphml") - def gml(gmlDir: String = null): ExportResult = - var pathToUse = getOrCreateExportPath(gmlDir) - traversal - .map { method => - MethodSubGraph( - methodName = method.name, - methodFullName = method.fullName, - filename = method.location.filename, - nodes = method.ast.toSet - ) - } - .map { case subGraph @ MethodSubGraph(methodName, methodFullName, filename, nodes) => - val methodHash = Fingerprinting.calculate_hash(methodFullName) - GraphMLExporter.runExport( - nodes, - subGraph.edges, - Paths.get( - pathToUse, - s"${methodName}-${sanitizeFilename(filename)}-${methodHash.slice(0, 8)}.graphml" - ) - ) - } - .reduce(plus) - end gml - - def gml: ExportResult = gml(null) - - def dot(dotDir: String = null): ExportResult = - var pathToUse = getOrCreateExportPath(dotDir) - traversal - .map { method => - MethodSubGraph( - methodName = method.name, - methodFullName = method.fullName, - filename = method.location.filename, - nodes = method.ast.toSet - ) - } - .map { case subGraph @ MethodSubGraph(methodName, methodFullName, filename, nodes) => - val methodHash = Fingerprinting.calculate_hash(methodFullName) - DotExporter.runExport( - nodes, - subGraph.edges, - Paths.get( - pathToUse, - s"${methodName}-${sanitizeFilename(filename)}-${methodHash.slice(0, 8)}.dot" - ) - ) - } - .reduce(plus) - end dot - - def dot: ExportResult = dot(null) - - def exportAllRepr(dotCfgDir: String = null): Unit = - var pathToUse = getOrCreateExportPath(dotCfgDir) - traversal - .foreach { method => - val methodName = method.name - val methodFullName = method.fullName - val filename = method.location.filename - val methodHash = Fingerprinting.calculate_hash(methodFullName) - File( - pathToUse, - s"${methodName}-${sanitizeFilename(filename)}-${methodHash.slice(0, 8)}-ast.dot" - ).write(method.dotAst.head) - File( - pathToUse, - s"${methodName}-${sanitizeFilename(filename)}-${methodHash.slice(0, 8)}-cdg.dot" - ).write(method.dotCdg.head) - File( - pathToUse, - s"${methodName}-${sanitizeFilename(filename)}-${methodHash.slice(0, 8)}-cfg.dot" - ).write(method.dotCfg.head) - } - end exportAllRepr - - def exportAllRepr: Unit = exportAllRepr(null) + def exportAllRepr: Unit = exportAllRepr(null) end MethodTraversal diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/NamespaceBlockTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/NamespaceBlockTraversal.scala index 1cbe0397..856b981f 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/NamespaceBlockTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/NamespaceBlockTraversal.scala @@ -5,15 +5,15 @@ import io.shiftleft.semanticcpg.language.* class NamespaceBlockTraversal(val traversal: Iterator[NamespaceBlock]) extends AnyVal: - /** Namespaces for namespace blocks. - */ - def namespace: Iterator[Namespace] = - traversal.flatMap(_.refOut) + /** Namespaces for namespace blocks. + */ + def namespace: Iterator[Namespace] = + traversal.flatMap(_.refOut) - /** The type declarations defined in this namespace - */ - def typeDecl: Iterator[TypeDecl] = - traversal.flatMap(_._typeDeclViaAstOut) + /** The type declarations defined in this namespace + */ + def typeDecl: Iterator[TypeDecl] = + traversal.flatMap(_._typeDeclViaAstOut) - def method: Iterator[Method] = - traversal.flatMap(_._methodViaAstOut) + def method: Iterator[Method] = + traversal.flatMap(_._methodViaAstOut) diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/NamespaceTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/NamespaceTraversal.scala index d8376151..2fc5a048 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/NamespaceTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/NamespaceTraversal.scala @@ -7,26 +7,26 @@ import io.shiftleft.semanticcpg.language.* */ class NamespaceTraversal(val traversal: Iterator[Namespace]) extends AnyVal: - /** The type declarations defined in this namespace - */ - def typeDecl: Iterator[TypeDecl] = - traversal.flatMap(_.refIn).flatMap(_._typeDeclViaAstOut) + /** The type declarations defined in this namespace + */ + def typeDecl: Iterator[TypeDecl] = + traversal.flatMap(_.refIn).flatMap(_._typeDeclViaAstOut) - /** Methods defined in this namespace - */ - def method: Iterator[Method] = - traversal.flatMap(_.refIn).flatMap(_._methodViaAstOut) + /** Methods defined in this namespace + */ + def method: Iterator[Method] = + traversal.flatMap(_.refIn).flatMap(_._methodViaAstOut) - /** External namespaces - any namespaces which contain one or more external type. - */ - def external: Iterator[Namespace] = - traversal.where(_.typeDecl.external) + /** External namespaces - any namespaces which contain one or more external type. + */ + def external: Iterator[Namespace] = + traversal.where(_.typeDecl.external) - /** Internal namespaces - any namespaces which contain one or more internal type - */ - def internal: Iterator[Namespace] = - traversal.where(_.typeDecl.internal) + /** Internal namespaces - any namespaces which contain one or more internal type + */ + def internal: Iterator[Namespace] = + traversal.where(_.typeDecl.internal) end NamespaceTraversal object NamespaceTraversal: - val globalNamespaceName = "" + val globalNamespaceName = "" diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/TypeDeclTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/TypeDeclTraversal.scala index 1250b5f4..fddbe738 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/TypeDeclTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/TypeDeclTraversal.scala @@ -7,107 +7,107 @@ import io.shiftleft.semanticcpg.language.* /** Type declaration - possibly a template that requires instantiation */ class TypeDeclTraversal(val traversal: Iterator[TypeDecl]) extends AnyVal: - import TypeDeclTraversal.* - - /** Annotations of the type declaration - */ - def annotation: Iterator[nodes.Annotation] = - traversal.flatMap(_._annotationViaAstOut) - - /** Types referencing to this type declaration. - */ - def referencingType: Iterator[Type] = - traversal.flatMap(_.refIn) - - /** Namespace in which this type declaration is defined - */ - def namespace: Iterator[Namespace] = - traversal.flatMap(_.namespaceBlock).namespace - - /** Methods defined as part of this type - */ - def method: Iterator[Method] = - canonicalType.flatMap(_._methodViaAstOut) - - /** Filter for type declarations contained in the analyzed code. - */ - def internal: Iterator[TypeDecl] = - canonicalType.isExternal(false) - - /** Filter for type declarations not contained in the analyzed code. - */ - def external: Iterator[TypeDecl] = - canonicalType.isExternal(true) - - /** Member variables - */ - def member: Iterator[Member] = - canonicalType.flatMap(_._memberViaAstOut) - - /** Direct base types in the inheritance graph. - */ - def baseType: Iterator[Type] = - canonicalType.flatMap(_._typeViaInheritsFromOut) - - /** Direct base type declaration. - */ - def derivedTypeDecl: Iterator[TypeDecl] = - referencingType.derivedTypeDecl - - /** Direct and transitive base type declaration. - */ - def derivedTypeDeclTransitive: Iterator[TypeDecl] = - traversal.repeat(_.derivedTypeDecl)(_.emitAllButFirst) - - /** Direct base type declaration. - */ - def baseTypeDecl: Iterator[TypeDecl] = - traversal.baseType.referencedTypeDecl - - /** Direct and transitive base type declaration. - */ - def baseTypeDeclTransitive: Iterator[TypeDecl] = - traversal.repeat(_.baseTypeDecl)(_.emitAllButFirst) - - /** Traverse to alias type declarations. - */ - def isAlias: Iterator[TypeDecl] = - traversal.filter(_.aliasTypeFullName.isDefined) - - /** Traverse to canonical type declarations. - */ - def isCanonical: Iterator[TypeDecl] = - traversal.filter(_.aliasTypeFullName.isEmpty) - - /** If this is an alias type declaration, go to its underlying type declaration else unchanged. - */ - def unravelAlias: Iterator[TypeDecl] = - traversal.map { typeDecl => - val alias = - for - tpe <- typeDecl.aliasedType.nextOption() - typeDecl <- tpe.referencedTypeDecl.nextOption() - yield typeDecl - - alias.getOrElse(typeDecl) - } - - /** Traverse to canonical type which means unravel aliases until we find a non alias type - * declaration. - */ - def canonicalType: Iterator[TypeDecl] = - traversal.repeat(_.unravelAlias)(_.until(_.isCanonical).maxDepth(maxAliasExpansions)) - - /** Direct alias type declarations. - */ - def aliasTypeDecl: Iterator[TypeDecl] = - referencingType.aliasTypeDecl - - /** Direct and transitive alias type declarations. - */ - def aliasTypeDeclTransitive: Iterator[TypeDecl] = - traversal.repeat(_.aliasTypeDecl)(_.emitAllButFirst) + import TypeDeclTraversal.* + + /** Annotations of the type declaration + */ + def annotation: Iterator[nodes.Annotation] = + traversal.flatMap(_._annotationViaAstOut) + + /** Types referencing to this type declaration. + */ + def referencingType: Iterator[Type] = + traversal.flatMap(_.refIn) + + /** Namespace in which this type declaration is defined + */ + def namespace: Iterator[Namespace] = + traversal.flatMap(_.namespaceBlock).namespace + + /** Methods defined as part of this type + */ + def method: Iterator[Method] = + canonicalType.flatMap(_._methodViaAstOut) + + /** Filter for type declarations contained in the analyzed code. + */ + def internal: Iterator[TypeDecl] = + canonicalType.isExternal(false) + + /** Filter for type declarations not contained in the analyzed code. + */ + def external: Iterator[TypeDecl] = + canonicalType.isExternal(true) + + /** Member variables + */ + def member: Iterator[Member] = + canonicalType.flatMap(_._memberViaAstOut) + + /** Direct base types in the inheritance graph. + */ + def baseType: Iterator[Type] = + canonicalType.flatMap(_._typeViaInheritsFromOut) + + /** Direct base type declaration. + */ + def derivedTypeDecl: Iterator[TypeDecl] = + referencingType.derivedTypeDecl + + /** Direct and transitive base type declaration. + */ + def derivedTypeDeclTransitive: Iterator[TypeDecl] = + traversal.repeat(_.derivedTypeDecl)(_.emitAllButFirst) + + /** Direct base type declaration. + */ + def baseTypeDecl: Iterator[TypeDecl] = + traversal.baseType.referencedTypeDecl + + /** Direct and transitive base type declaration. + */ + def baseTypeDeclTransitive: Iterator[TypeDecl] = + traversal.repeat(_.baseTypeDecl)(_.emitAllButFirst) + + /** Traverse to alias type declarations. + */ + def isAlias: Iterator[TypeDecl] = + traversal.filter(_.aliasTypeFullName.isDefined) + + /** Traverse to canonical type declarations. + */ + def isCanonical: Iterator[TypeDecl] = + traversal.filter(_.aliasTypeFullName.isEmpty) + + /** If this is an alias type declaration, go to its underlying type declaration else unchanged. + */ + def unravelAlias: Iterator[TypeDecl] = + traversal.map { typeDecl => + val alias = + for + tpe <- typeDecl.aliasedType.nextOption() + typeDecl <- tpe.referencedTypeDecl.nextOption() + yield typeDecl + + alias.getOrElse(typeDecl) + } + + /** Traverse to canonical type which means unravel aliases until we find a non alias type + * declaration. + */ + def canonicalType: Iterator[TypeDecl] = + traversal.repeat(_.unravelAlias)(_.until(_.isCanonical).maxDepth(maxAliasExpansions)) + + /** Direct alias type declarations. + */ + def aliasTypeDecl: Iterator[TypeDecl] = + referencingType.aliasTypeDecl + + /** Direct and transitive alias type declarations. + */ + def aliasTypeDeclTransitive: Iterator[TypeDecl] = + traversal.repeat(_.aliasTypeDecl)(_.emitAllButFirst) end TypeDeclTraversal object TypeDeclTraversal: - private val maxAliasExpansions = 100 + private val maxAliasExpansions = 100 diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/TypeTraversal.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/TypeTraversal.scala index eaedbd0c..2db6004d 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/TypeTraversal.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/language/types/structure/TypeTraversal.scala @@ -6,88 +6,88 @@ import io.shiftleft.semanticcpg.language.* class TypeTraversal(val traversal: Iterator[Type]) extends AnyVal: - /** Annotations of the corresponding type declaration. - */ - def annotation: Iterator[nodes.Annotation] = - traversal.referencedTypeDecl.annotation - - /** Namespaces in which the corresponding type declaration is defined. - */ - def namespace: Iterator[Namespace] = - traversal.referencedTypeDecl.namespace - - /** Methods defined on the corresponding type declaration. - */ - def method: Iterator[Method] = - traversal.referencedTypeDecl.method - - /** Filter for types whos corresponding type declaration is in the analyzed jar. - */ - def internal: Iterator[Type] = - traversal.where(_.referencedTypeDecl.internal) - - /** Filter for types whos corresponding type declaration is not in the analyzed jar. - */ - def external: Iterator[Type] = - traversal.where(_.referencedTypeDecl.external) - - /** Member variables of the corresponding type declaration. - */ - def member: Iterator[Member] = - traversal.referencedTypeDecl.member - - /** Direct base types of the corresponding type declaration in the inheritance graph. - */ - def baseType: Iterator[Type] = - traversal.referencedTypeDecl.baseType - - /** Direct and transitive base types of the corresponding type declaration. - */ - def baseTypeTransitive: Iterator[Type] = - traversal.repeat(_.baseType)(_.emitAllButFirst) - - /** Direct derived types. - */ - def derivedType: Iterator[Type] = - derivedTypeDecl.referencingType - - /** Direct and transitive derived types. - */ - def derivedTypeTransitive: Iterator[Type] = - traversal.repeat(_.derivedType)(_.emitAllButFirst) - - /** Type declarations which derive from this type. - */ - def derivedTypeDecl: Iterator[TypeDecl] = - traversal.flatMap(_.inheritsFromIn) - - /** Direct alias types. - */ - def aliasType: Iterator[Type] = - traversal.aliasTypeDecl.referencingType - - /** Direct and transitive alias types. - */ - def aliasTypeTransitive: Iterator[Type] = - traversal.repeat(_.aliasType)(_.emitAllButFirst) - - def localOfType: Iterator[Local] = - traversal.flatMap(_._localViaEvalTypeIn) - - def memberOfType: Iterator[Member] = - traversal.flatMap(_.evalTypeIn).collectAll[Member] - - @deprecated("Please use `parameterOfType`") - def parameter: Iterator[MethodParameterIn] = parameterOfType - - def parameterOfType: Iterator[MethodParameterIn] = - traversal.flatMap(_.evalTypeIn).collectAll[MethodParameterIn] - - def methodReturnOfType: Iterator[MethodReturn] = - traversal.flatMap(_.evalTypeIn).collectAll[MethodReturn] - - def expressionOfType: Iterator[Expression] = expression - - def expression: Iterator[Expression] = - traversal.flatMap(_.evalTypeIn).collectAll[Expression] + /** Annotations of the corresponding type declaration. + */ + def annotation: Iterator[nodes.Annotation] = + traversal.referencedTypeDecl.annotation + + /** Namespaces in which the corresponding type declaration is defined. + */ + def namespace: Iterator[Namespace] = + traversal.referencedTypeDecl.namespace + + /** Methods defined on the corresponding type declaration. + */ + def method: Iterator[Method] = + traversal.referencedTypeDecl.method + + /** Filter for types whos corresponding type declaration is in the analyzed jar. + */ + def internal: Iterator[Type] = + traversal.where(_.referencedTypeDecl.internal) + + /** Filter for types whos corresponding type declaration is not in the analyzed jar. + */ + def external: Iterator[Type] = + traversal.where(_.referencedTypeDecl.external) + + /** Member variables of the corresponding type declaration. + */ + def member: Iterator[Member] = + traversal.referencedTypeDecl.member + + /** Direct base types of the corresponding type declaration in the inheritance graph. + */ + def baseType: Iterator[Type] = + traversal.referencedTypeDecl.baseType + + /** Direct and transitive base types of the corresponding type declaration. + */ + def baseTypeTransitive: Iterator[Type] = + traversal.repeat(_.baseType)(_.emitAllButFirst) + + /** Direct derived types. + */ + def derivedType: Iterator[Type] = + derivedTypeDecl.referencingType + + /** Direct and transitive derived types. + */ + def derivedTypeTransitive: Iterator[Type] = + traversal.repeat(_.derivedType)(_.emitAllButFirst) + + /** Type declarations which derive from this type. + */ + def derivedTypeDecl: Iterator[TypeDecl] = + traversal.flatMap(_.inheritsFromIn) + + /** Direct alias types. + */ + def aliasType: Iterator[Type] = + traversal.aliasTypeDecl.referencingType + + /** Direct and transitive alias types. + */ + def aliasTypeTransitive: Iterator[Type] = + traversal.repeat(_.aliasType)(_.emitAllButFirst) + + def localOfType: Iterator[Local] = + traversal.flatMap(_._localViaEvalTypeIn) + + def memberOfType: Iterator[Member] = + traversal.flatMap(_.evalTypeIn).collectAll[Member] + + @deprecated("Please use `parameterOfType`") + def parameter: Iterator[MethodParameterIn] = parameterOfType + + def parameterOfType: Iterator[MethodParameterIn] = + traversal.flatMap(_.evalTypeIn).collectAll[MethodParameterIn] + + def methodReturnOfType: Iterator[MethodReturn] = + traversal.flatMap(_.evalTypeIn).collectAll[MethodReturn] + + def expressionOfType: Iterator[Expression] = expression + + def expression: Iterator[Expression] = + traversal.flatMap(_.evalTypeIn).collectAll[Expression] end TypeTraversal diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/layers/LayerCreator.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/layers/LayerCreator.scala index 8e1bac2a..8347e172 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/layers/LayerCreator.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/layers/LayerCreator.scala @@ -9,51 +9,51 @@ import org.slf4j.{Logger, LoggerFactory} abstract class LayerCreator: - private val logger: Logger = LoggerFactory.getLogger(this.getClass) - - val overlayName: String - val description: String - val dependsOn: List[String] = List() - - /** If the LayerCreator modifies the CPG, then we store its name in the CPGs metadata and - * disallow rerunning the creator, that is, applying the layer twice. - */ - protected val storeOverlayName: Boolean = true - - def run(context: LayerCreatorContext, storeUndoInfo: Boolean = false): Unit = - val appliedOverlays = Overlays.appliedOverlays(context.cpg).toSet - if !dependsOn.toSet.subsetOf(appliedOverlays) then - logger.warn( - s"${this.getClass.getName} depends on $dependsOn but CPG only has $appliedOverlays - skipping creation" - ) - else if appliedOverlays.contains(overlayName) then - logger.warn(s"The overlay $overlayName already exists - skipping creation") - else - create(context, storeUndoInfo) - if storeOverlayName then - Overlays.appendOverlayName(context.cpg, overlayName) - - protected def initSerializedCpg( - outputDir: Option[String], - passName: String, - index: Int = 0 - ): SerializedCpg = - outputDir match - case Some(dir) => - new SerializedCpg((File(dir) / s"${index}_$passName").path.toAbsolutePath.toString) - case None => new SerializedCpg() - - protected def runPass( - pass: CpgPassBase, - context: LayerCreatorContext, - storeUndoInfo: Boolean, - index: Int = 0 - ): Unit = - val serializedCpg = initSerializedCpg(context.outputDir, pass.name, index) - pass.createApplySerializeAndStore(serializedCpg, inverse = storeUndoInfo) - serializedCpg.close() - - def create(context: LayerCreatorContext, storeUndoInfo: Boolean = false): Unit + private val logger: Logger = LoggerFactory.getLogger(this.getClass) + + val overlayName: String + val description: String + val dependsOn: List[String] = List() + + /** If the LayerCreator modifies the CPG, then we store its name in the CPGs metadata and disallow + * rerunning the creator, that is, applying the layer twice. + */ + protected val storeOverlayName: Boolean = true + + def run(context: LayerCreatorContext, storeUndoInfo: Boolean = false): Unit = + val appliedOverlays = Overlays.appliedOverlays(context.cpg).toSet + if !dependsOn.toSet.subsetOf(appliedOverlays) then + logger.warn( + s"${this.getClass.getName} depends on $dependsOn but CPG only has $appliedOverlays - skipping creation" + ) + else if appliedOverlays.contains(overlayName) then + logger.warn(s"The overlay $overlayName already exists - skipping creation") + else + create(context, storeUndoInfo) + if storeOverlayName then + Overlays.appendOverlayName(context.cpg, overlayName) + + protected def initSerializedCpg( + outputDir: Option[String], + passName: String, + index: Int = 0 + ): SerializedCpg = + outputDir match + case Some(dir) => + new SerializedCpg((File(dir) / s"${index}_$passName").path.toAbsolutePath.toString) + case None => new SerializedCpg() + + protected def runPass( + pass: CpgPassBase, + context: LayerCreatorContext, + storeUndoInfo: Boolean, + index: Int = 0 + ): Unit = + val serializedCpg = initSerializedCpg(context.outputDir, pass.name, index) + pass.createApplySerializeAndStore(serializedCpg, inverse = storeUndoInfo) + serializedCpg.close() + + def create(context: LayerCreatorContext, storeUndoInfo: Boolean = false): Unit end LayerCreator class LayerCreatorContext(val cpg: Cpg, val outputDir: Option[String] = None) {} diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/testing/DummyNode.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/testing/DummyNode.scala index 84c572f9..fea5a177 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/testing/DummyNode.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/testing/DummyNode.scala @@ -7,54 +7,54 @@ import java.util /** mixin trait for test nodes */ trait DummyNodeImpl extends StoredNode: - // Members declared in overflowdb.Element - def graph(): overflowdb.Graph = ??? - def property[A](x$1: overflowdb.PropertyKey[A]): A = ??? - def property(x$1: String): Object = ??? - def propertyKeys(): java.util.Set[String] = ??? - def propertiesMap(): java.util.Map[String, Object] = ??? - def propertyOption(x$1: String): java.util.Optional[Object] = ??? - def propertyOption[A](x$1: overflowdb.PropertyKey[A]): java.util.Optional[A] = ??? - override def addEdgeImpl(label: String, inNode: Node, keyValues: Any*): Edge = ??? - override def addEdgeImpl( - label: String, - inNode: Node, - keyValues: util.Map[String, AnyRef] - ): Edge = ??? - override def addEdgeSilentImpl(label: String, inNode: Node, keyValues: Any*): Unit = ??? - override def addEdgeSilentImpl( - label: String, - inNode: Node, - keyValues: util.Map[String, AnyRef] - ): Unit = ??? - override def setPropertyImpl(key: String, value: Any): Unit = ??? - override def setPropertyImpl[A](key: PropertyKey[A], value: A): Unit = ??? - override def setPropertyImpl(property: Property[?]): Unit = ??? - override def removePropertyImpl(key: String): Unit = ??? - override def removeImpl(): Unit = ??? + // Members declared in overflowdb.Element + def graph(): overflowdb.Graph = ??? + def property[A](x$1: overflowdb.PropertyKey[A]): A = ??? + def property(x$1: String): Object = ??? + def propertyKeys(): java.util.Set[String] = ??? + def propertiesMap(): java.util.Map[String, Object] = ??? + def propertyOption(x$1: String): java.util.Optional[Object] = ??? + def propertyOption[A](x$1: overflowdb.PropertyKey[A]): java.util.Optional[A] = ??? + override def addEdgeImpl(label: String, inNode: Node, keyValues: Any*): Edge = ??? + override def addEdgeImpl( + label: String, + inNode: Node, + keyValues: util.Map[String, AnyRef] + ): Edge = ??? + override def addEdgeSilentImpl(label: String, inNode: Node, keyValues: Any*): Unit = ??? + override def addEdgeSilentImpl( + label: String, + inNode: Node, + keyValues: util.Map[String, AnyRef] + ): Unit = ??? + override def setPropertyImpl(key: String, value: Any): Unit = ??? + override def setPropertyImpl[A](key: PropertyKey[A], value: A): Unit = ??? + override def setPropertyImpl(property: Property[?]): Unit = ??? + override def removePropertyImpl(key: String): Unit = ??? + override def removeImpl(): Unit = ??? - // Members declared in scala.Equals - def canEqual(that: Any): Boolean = ??? + // Members declared in scala.Equals + def canEqual(that: Any): Boolean = ??? - def both(x$1: String*): java.util.Iterator[overflowdb.Node] = ??? - def both(): java.util.Iterator[overflowdb.Node] = ??? - def bothE(x$1: String*): java.util.Iterator[overflowdb.Edge] = ??? - def bothE(): java.util.Iterator[overflowdb.Edge] = ??? - def id(): Long = ??? - def in(x$1: String*): java.util.Iterator[overflowdb.Node] = ??? - def in(): java.util.Iterator[overflowdb.Node] = ??? - def inE(x$1: String*): java.util.Iterator[overflowdb.Edge] = ??? - def inE(): java.util.Iterator[overflowdb.Edge] = ??? - def out(x$1: String*): java.util.Iterator[overflowdb.Node] = ??? - def out(): java.util.Iterator[overflowdb.Node] = ??? - def outE(x$1: String*): java.util.Iterator[overflowdb.Edge] = ??? - def outE(): java.util.Iterator[overflowdb.Edge] = ??? + def both(x$1: String*): java.util.Iterator[overflowdb.Node] = ??? + def both(): java.util.Iterator[overflowdb.Node] = ??? + def bothE(x$1: String*): java.util.Iterator[overflowdb.Edge] = ??? + def bothE(): java.util.Iterator[overflowdb.Edge] = ??? + def id(): Long = ??? + def in(x$1: String*): java.util.Iterator[overflowdb.Node] = ??? + def in(): java.util.Iterator[overflowdb.Node] = ??? + def inE(x$1: String*): java.util.Iterator[overflowdb.Edge] = ??? + def inE(): java.util.Iterator[overflowdb.Edge] = ??? + def out(x$1: String*): java.util.Iterator[overflowdb.Node] = ??? + def out(): java.util.Iterator[overflowdb.Node] = ??? + def outE(x$1: String*): java.util.Iterator[overflowdb.Edge] = ??? + def outE(): java.util.Iterator[overflowdb.Edge] = ??? - // Members declared in scala.Product - def productArity: Int = ??? - def productElement(n: Int): Any = ??? + // Members declared in scala.Product + def productArity: Int = ??? + def productElement(n: Int): Any = ??? - // Members declared in io.shiftleft.codepropertygraph.generated.nodes.StoredNode - def productElementLabel(n: Int): String = ??? - def valueMap: java.util.Map[String, AnyRef] = ??? + // Members declared in io.shiftleft.codepropertygraph.generated.nodes.StoredNode + def productElementLabel(n: Int): String = ??? + def valueMap: java.util.Map[String, AnyRef] = ??? end DummyNodeImpl diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/testing/package.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/testing/package.scala index 18af5f03..1ee58c0a 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/testing/package.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/testing/package.scala @@ -10,206 +10,206 @@ import overflowdb.BatchedUpdate.DiffGraphBuilder package object testing: - object MockCpg: - - def apply(): MockCpg = new MockCpg - - def apply(f: (DiffGraphBuilder, Cpg) => Unit): MockCpg = new MockCpg().withCustom(f) - - case class MockCpg(cpg: Cpg = Cpg.emptyCpg): - - def withMetaData(language: String = Languages.C): MockCpg = withMetaData(language, Nil) - - def withMetaData(language: String, overlays: List[String]): MockCpg = - withCustom { (diffGraph, _) => - diffGraph.addNode(NewMetaData().language(language).overlays(overlays)) - } - - def withFile(filename: String): MockCpg = - withCustom { (graph, _) => - graph.addNode(NewFile().name(filename)) - } - - def withNamespace(name: String, inFile: Option[String] = None): MockCpg = - withCustom { (graph, _) => - val namespaceBlock = NewNamespaceBlock().name(name) - val namespace = NewNamespace().name(name) - graph.addNode(namespaceBlock) - graph.addNode(namespace) - graph.addEdge(namespaceBlock, namespace, EdgeTypes.REF) - if inFile.isDefined then - val fileNode = cpg.file.name(inFile.get).head - graph.addEdge(namespaceBlock, fileNode, EdgeTypes.SOURCE_FILE) - } - - def withTypeDecl( - name: String, - isExternal: Boolean = false, - inNamespace: Option[String] = None, - inFile: Option[String] = None - ): MockCpg = - withCustom { (graph, _) => - val typeNode = NewType().name(name) - val typeDeclNode = NewTypeDecl() - .name(name) - .fullName(name) - .isExternal(isExternal) - - val member = NewMember().name("amember") - val modifier = NewModifier().modifierType(ModifierTypes.STATIC) - - graph.addNode(typeDeclNode) - graph.addNode(typeNode) - graph.addNode(member) - graph.addNode(modifier) - graph.addEdge(typeNode, typeDeclNode, EdgeTypes.REF) - graph.addEdge(typeDeclNode, member, EdgeTypes.AST) - graph.addEdge(member, modifier, EdgeTypes.AST) - - if inNamespace.isDefined then - val namespaceBlock = cpg.namespaceBlock(inNamespace.get).head - graph.addEdge(namespaceBlock, typeDeclNode, EdgeTypes.AST) - if inFile.isDefined then - val fileNode = cpg.file.name(inFile.get).head - graph.addEdge(typeDeclNode, fileNode, EdgeTypes.SOURCE_FILE) - } - - def withMethod( - name: String, - external: Boolean = false, - inTypeDecl: Option[String] = None, - fileName: String = "" - ): MockCpg = - withCustom { (graph, _) => - val retParam = NewMethodReturn().typeFullName("int").order(10) - val param = NewMethodParameterIn().order(1).index(1).name("param1") - val paramType = NewType().name("paramtype") - val paramOut = NewMethodParameterOut().name("param1").order(1) - val method = - NewMethod().isExternal(external).name(name).fullName(name).signature( - "asignature" - ).filename(fileName) - val block = NewBlock().typeFullName("int") - val modifier = NewModifier().modifierType("modifiertype") - - graph.addNode(method) - graph.addNode(retParam) - graph.addNode(param) - graph.addNode(paramType) - graph.addNode(paramOut) - graph.addNode(block) - graph.addNode(modifier) - graph.addEdge(method, retParam, EdgeTypes.AST) - graph.addEdge(method, param, EdgeTypes.AST) - graph.addEdge(param, paramOut, EdgeTypes.PARAMETER_LINK) - graph.addEdge(method, block, EdgeTypes.AST) - graph.addEdge(param, paramType, EdgeTypes.EVAL_TYPE) - graph.addEdge(paramOut, paramType, EdgeTypes.EVAL_TYPE) - graph.addEdge(method, modifier, EdgeTypes.AST) - - if inTypeDecl.isDefined then - val typeDeclNode = cpg.typeDecl.name(inTypeDecl.get).head - graph.addEdge(typeDeclNode, method, EdgeTypes.AST) - } - - def withTagsOnMethod( - methodName: String, - methodTags: List[(String, String)] = List(), - paramTags: List[(String, String)] = List() - ): MockCpg = - withCustom { (graph, cpg) => - implicit val diffGraph: DiffGraphBuilder = graph - methodTags.foreach { case (k, v) => - cpg.method.name(methodName).newTagNodePair(k, v).store()(diffGraph) - } - paramTags.foreach { case (k, v) => - cpg.method.name(methodName).parameter.newTagNodePair(k, v).store()(diffGraph) - } - } - - def withCallInMethod( - methodName: String, - callName: String, - code: Option[String] = None - ): MockCpg = - withCustom { (graph, cpg) => - val methodNode = cpg.method.name(methodName).head - val blockNode = methodNode.block - val callNode = NewCall().name(callName).code(code.getOrElse(callName)) - graph.addNode(callNode) - graph.addEdge(blockNode, callNode, EdgeTypes.AST) - graph.addEdge(methodNode, callNode, EdgeTypes.CONTAINS) - } - - def withMethodCall( - calledMethod: String, - callingMethod: String, - code: Option[String] = None - ): MockCpg = - withCustom { (graph, cpg) => - val callingMethodNode = cpg.method.name(callingMethod).head - val calledMethodNode = cpg.method.name(calledMethod).head - val callNode = NewCall().name(calledMethod).code(code.getOrElse(calledMethod)) - graph.addEdge(callNode, calledMethodNode, EdgeTypes.CALL) - graph.addEdge(callingMethodNode, callNode, EdgeTypes.CONTAINS) - } - - def withLocalInMethod(methodName: String, localName: String): MockCpg = - withCustom { (graph, cpg) => - val methodNode = cpg.method.name(methodName).head - val blockNode = methodNode.block - val typeNode = NewType().name("alocaltype") - val localNode = NewLocal().name(localName).typeFullName("alocaltype") - graph.addNode(localNode) - graph.addNode(typeNode) - graph.addEdge(blockNode, localNode, EdgeTypes.AST) - graph.addEdge(localNode, typeNode, EdgeTypes.EVAL_TYPE) - } - - def withLiteralArgument(callName: String, literalCode: String): MockCpg = - withCustom { (graph, cpg) => - val callNode = cpg.call.name(callName).head - val methodNode = callNode.method - val literalNode = NewLiteral().code(literalCode) - val typeDecl = NewTypeDecl() - .name("ATypeDecl") - .fullName("ATypeDecl") - - graph.addNode(typeDecl) - graph.addNode(literalNode) - graph.addEdge(callNode, literalNode, EdgeTypes.AST) - graph.addEdge(methodNode, literalNode, EdgeTypes.CONTAINS) - } - - def withIdentifierArgument(callName: String, name: String, index: Int = 1): MockCpg = - withArgument(callName, NewIdentifier().name(name).argumentIndex(index)) - - def withCallArgument( - callName: String, - callArgName: String, - code: String = "", - index: Int = 1 - ): MockCpg = - withArgument(callName, NewCall().name(callArgName).code(code).argumentIndex(index)) - - def withArgument(callName: String, newNode: NewNode): MockCpg = withCustom { (graph, cpg) => - val callNode = cpg.call.name(callName).head - val methodNode = callNode.method - val typeDecl = NewTypeDecl().name("abc") - graph.addEdge(callNode, newNode, EdgeTypes.AST) - graph.addEdge(callNode, newNode, EdgeTypes.ARGUMENT) - graph.addEdge(methodNode, newNode, EdgeTypes.CONTAINS) - graph.addEdge(newNode, typeDecl, EdgeTypes.REF) - graph.addNode(newNode) + object MockCpg: + + def apply(): MockCpg = new MockCpg + + def apply(f: (DiffGraphBuilder, Cpg) => Unit): MockCpg = new MockCpg().withCustom(f) + + case class MockCpg(cpg: Cpg = Cpg.emptyCpg): + + def withMetaData(language: String = Languages.C): MockCpg = withMetaData(language, Nil) + + def withMetaData(language: String, overlays: List[String]): MockCpg = + withCustom { (diffGraph, _) => + diffGraph.addNode(NewMetaData().language(language).overlays(overlays)) + } + + def withFile(filename: String): MockCpg = + withCustom { (graph, _) => + graph.addNode(NewFile().name(filename)) + } + + def withNamespace(name: String, inFile: Option[String] = None): MockCpg = + withCustom { (graph, _) => + val namespaceBlock = NewNamespaceBlock().name(name) + val namespace = NewNamespace().name(name) + graph.addNode(namespaceBlock) + graph.addNode(namespace) + graph.addEdge(namespaceBlock, namespace, EdgeTypes.REF) + if inFile.isDefined then + val fileNode = cpg.file.name(inFile.get).head + graph.addEdge(namespaceBlock, fileNode, EdgeTypes.SOURCE_FILE) + } + + def withTypeDecl( + name: String, + isExternal: Boolean = false, + inNamespace: Option[String] = None, + inFile: Option[String] = None + ): MockCpg = + withCustom { (graph, _) => + val typeNode = NewType().name(name) + val typeDeclNode = NewTypeDecl() + .name(name) + .fullName(name) + .isExternal(isExternal) + + val member = NewMember().name("amember") + val modifier = NewModifier().modifierType(ModifierTypes.STATIC) + + graph.addNode(typeDeclNode) + graph.addNode(typeNode) + graph.addNode(member) + graph.addNode(modifier) + graph.addEdge(typeNode, typeDeclNode, EdgeTypes.REF) + graph.addEdge(typeDeclNode, member, EdgeTypes.AST) + graph.addEdge(member, modifier, EdgeTypes.AST) + + if inNamespace.isDefined then + val namespaceBlock = cpg.namespaceBlock(inNamespace.get).head + graph.addEdge(namespaceBlock, typeDeclNode, EdgeTypes.AST) + if inFile.isDefined then + val fileNode = cpg.file.name(inFile.get).head + graph.addEdge(typeDeclNode, fileNode, EdgeTypes.SOURCE_FILE) + } + + def withMethod( + name: String, + external: Boolean = false, + inTypeDecl: Option[String] = None, + fileName: String = "" + ): MockCpg = + withCustom { (graph, _) => + val retParam = NewMethodReturn().typeFullName("int").order(10) + val param = NewMethodParameterIn().order(1).index(1).name("param1") + val paramType = NewType().name("paramtype") + val paramOut = NewMethodParameterOut().name("param1").order(1) + val method = + NewMethod().isExternal(external).name(name).fullName(name).signature( + "asignature" + ).filename(fileName) + val block = NewBlock().typeFullName("int") + val modifier = NewModifier().modifierType("modifiertype") + + graph.addNode(method) + graph.addNode(retParam) + graph.addNode(param) + graph.addNode(paramType) + graph.addNode(paramOut) + graph.addNode(block) + graph.addNode(modifier) + graph.addEdge(method, retParam, EdgeTypes.AST) + graph.addEdge(method, param, EdgeTypes.AST) + graph.addEdge(param, paramOut, EdgeTypes.PARAMETER_LINK) + graph.addEdge(method, block, EdgeTypes.AST) + graph.addEdge(param, paramType, EdgeTypes.EVAL_TYPE) + graph.addEdge(paramOut, paramType, EdgeTypes.EVAL_TYPE) + graph.addEdge(method, modifier, EdgeTypes.AST) + + if inTypeDecl.isDefined then + val typeDeclNode = cpg.typeDecl.name(inTypeDecl.get).head + graph.addEdge(typeDeclNode, method, EdgeTypes.AST) + } + + def withTagsOnMethod( + methodName: String, + methodTags: List[(String, String)] = List(), + paramTags: List[(String, String)] = List() + ): MockCpg = + withCustom { (graph, cpg) => + implicit val diffGraph: DiffGraphBuilder = graph + methodTags.foreach { case (k, v) => + cpg.method.name(methodName).newTagNodePair(k, v).store()(diffGraph) + } + paramTags.foreach { case (k, v) => + cpg.method.name(methodName).parameter.newTagNodePair(k, v).store()(diffGraph) + } + } + + def withCallInMethod( + methodName: String, + callName: String, + code: Option[String] = None + ): MockCpg = + withCustom { (graph, cpg) => + val methodNode = cpg.method.name(methodName).head + val blockNode = methodNode.block + val callNode = NewCall().name(callName).code(code.getOrElse(callName)) + graph.addNode(callNode) + graph.addEdge(blockNode, callNode, EdgeTypes.AST) + graph.addEdge(methodNode, callNode, EdgeTypes.CONTAINS) + } + + def withMethodCall( + calledMethod: String, + callingMethod: String, + code: Option[String] = None + ): MockCpg = + withCustom { (graph, cpg) => + val callingMethodNode = cpg.method.name(callingMethod).head + val calledMethodNode = cpg.method.name(calledMethod).head + val callNode = NewCall().name(calledMethod).code(code.getOrElse(calledMethod)) + graph.addEdge(callNode, calledMethodNode, EdgeTypes.CALL) + graph.addEdge(callingMethodNode, callNode, EdgeTypes.CONTAINS) + } + + def withLocalInMethod(methodName: String, localName: String): MockCpg = + withCustom { (graph, cpg) => + val methodNode = cpg.method.name(methodName).head + val blockNode = methodNode.block + val typeNode = NewType().name("alocaltype") + val localNode = NewLocal().name(localName).typeFullName("alocaltype") + graph.addNode(localNode) + graph.addNode(typeNode) + graph.addEdge(blockNode, localNode, EdgeTypes.AST) + graph.addEdge(localNode, typeNode, EdgeTypes.EVAL_TYPE) + } + + def withLiteralArgument(callName: String, literalCode: String): MockCpg = + withCustom { (graph, cpg) => + val callNode = cpg.call.name(callName).head + val methodNode = callNode.method + val literalNode = NewLiteral().code(literalCode) + val typeDecl = NewTypeDecl() + .name("ATypeDecl") + .fullName("ATypeDecl") + + graph.addNode(typeDecl) + graph.addNode(literalNode) + graph.addEdge(callNode, literalNode, EdgeTypes.AST) + graph.addEdge(methodNode, literalNode, EdgeTypes.CONTAINS) } - def withCustom(f: (DiffGraphBuilder, Cpg) => Unit): MockCpg = - val diffGraph = new DiffGraphBuilder - f(diffGraph, cpg) - class MyPass extends CpgPass(cpg): - override def run(builder: BatchedUpdate.DiffGraphBuilder): Unit = - builder.absorb(diffGraph) - new MyPass().createAndApply() - this - end MockCpg + def withIdentifierArgument(callName: String, name: String, index: Int = 1): MockCpg = + withArgument(callName, NewIdentifier().name(name).argumentIndex(index)) + + def withCallArgument( + callName: String, + callArgName: String, + code: String = "", + index: Int = 1 + ): MockCpg = + withArgument(callName, NewCall().name(callArgName).code(code).argumentIndex(index)) + + def withArgument(callName: String, newNode: NewNode): MockCpg = withCustom { (graph, cpg) => + val callNode = cpg.call.name(callName).head + val methodNode = callNode.method + val typeDecl = NewTypeDecl().name("abc") + graph.addEdge(callNode, newNode, EdgeTypes.AST) + graph.addEdge(callNode, newNode, EdgeTypes.ARGUMENT) + graph.addEdge(methodNode, newNode, EdgeTypes.CONTAINS) + graph.addEdge(newNode, typeDecl, EdgeTypes.REF) + graph.addNode(newNode) + } + + def withCustom(f: (DiffGraphBuilder, Cpg) => Unit): MockCpg = + val diffGraph = new DiffGraphBuilder + f(diffGraph, cpg) + class MyPass extends CpgPass(cpg): + override def run(builder: BatchedUpdate.DiffGraphBuilder): Unit = + builder.absorb(diffGraph) + new MyPass().createAndApply() + this + end MockCpg end testing diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/Fingerprinting.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/Fingerprinting.scala index 986a2a45..5ca71a8e 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/Fingerprinting.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/Fingerprinting.scala @@ -7,7 +7,7 @@ import java.security.MessageDigest object Fingerprinting: - def calculate_hash(content: String): String = - MessageDigest.getInstance("SHA-256") - .digest(content.getBytes("UTF-8")) - .map("%02x".format(_)).mkString + def calculate_hash(content: String): String = + MessageDigest.getInstance("SHA-256") + .digest(content.getBytes("UTF-8")) + .map("%02x".format(_)).mkString diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/MemberAccess.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/MemberAccess.scala index 74bef7a1..8b80d32c 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/MemberAccess.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/MemberAccess.scala @@ -4,32 +4,32 @@ import io.shiftleft.codepropertygraph.generated.Operators object MemberAccess: - /** For a given name, determine whether it is the name of a "member access" operation, e.g., - * ".memberAccess". - */ - def isGenericMemberAccessName(name: String): Boolean = - (name == Operators.memberAccess) || - (name == Operators.indirectComputedMemberAccess) || - (name == Operators.indirectMemberAccess) || - (name == Operators.computedMemberAccess) || - (name == Operators.indirection) || - (name == Operators.addressOf) || - (name == Operators.fieldAccess) || - (name == Operators.indirectFieldAccess) || - (name == Operators.indexAccess) || - (name == Operators.indirectIndexAccess) || - (name == Operators.pointerShift) || - (name == Operators.getElementPtr) + /** For a given name, determine whether it is the name of a "member access" operation, e.g., + * ".memberAccess". + */ + def isGenericMemberAccessName(name: String): Boolean = + (name == Operators.memberAccess) || + (name == Operators.indirectComputedMemberAccess) || + (name == Operators.indirectMemberAccess) || + (name == Operators.computedMemberAccess) || + (name == Operators.indirection) || + (name == Operators.addressOf) || + (name == Operators.fieldAccess) || + (name == Operators.indirectFieldAccess) || + (name == Operators.indexAccess) || + (name == Operators.indirectIndexAccess) || + (name == Operators.pointerShift) || + (name == Operators.getElementPtr) - def isFieldAccess(name: String): Boolean = - (name == Operators.memberAccess) || - (name == Operators.indirectComputedMemberAccess) || - (name == Operators.indirectMemberAccess) || - (name == Operators.computedMemberAccess) || - (name == Operators.indirection) || - (name == Operators.fieldAccess) || - (name == Operators.indirectFieldAccess) || - (name == Operators.indexAccess) || - (name == Operators.indirectIndexAccess) || - (name == Operators.getElementPtr) + def isFieldAccess(name: String): Boolean = + (name == Operators.memberAccess) || + (name == Operators.indirectComputedMemberAccess) || + (name == Operators.indirectMemberAccess) || + (name == Operators.computedMemberAccess) || + (name == Operators.indirection) || + (name == Operators.fieldAccess) || + (name == Operators.indirectFieldAccess) || + (name == Operators.indexAccess) || + (name == Operators.indirectIndexAccess) || + (name == Operators.getElementPtr) end MemberAccess diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/SecureXmlParsing.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/SecureXmlParsing.scala index 8a18ff7e..47f7a80e 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/SecureXmlParsing.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/SecureXmlParsing.scala @@ -5,19 +5,19 @@ import scala.util.Try import scala.xml.{Elem, XML} object SecureXmlParsing: - def parseXml(content: String): Option[Elem] = - Try { - val spf = SAXParserFactory.newInstance() + def parseXml(content: String): Option[Elem] = + Try { + val spf = SAXParserFactory.newInstance() - spf.setValidating(false) - spf.setNamespaceAware(false) - spf.setXIncludeAware(false) - spf.setFeature("http://xml.org/sax/features/validation", false) - spf.setFeature("http://apache.org/xml/features/disallow-doctype-decl", false) - spf.setFeature("http://apache.org/xml/features/nonvalidating/load-dtd-grammar", false) - spf.setFeature("http://apache.org/xml/features/nonvalidating/load-external-dtd", false) - spf.setFeature("http://xml.org/sax/features/external-parameter-entities", false) - spf.setFeature("http://xml.org/sax/features/external-general-entities", false) + spf.setValidating(false) + spf.setNamespaceAware(false) + spf.setXIncludeAware(false) + spf.setFeature("http://xml.org/sax/features/validation", false) + spf.setFeature("http://apache.org/xml/features/disallow-doctype-decl", false) + spf.setFeature("http://apache.org/xml/features/nonvalidating/load-dtd-grammar", false) + spf.setFeature("http://apache.org/xml/features/nonvalidating/load-external-dtd", false) + spf.setFeature("http://xml.org/sax/features/external-parameter-entities", false) + spf.setFeature("http://xml.org/sax/features/external-general-entities", false) - XML.withSAXParser(spf.newSAXParser()).loadString(content) - }.toOption + XML.withSAXParser(spf.newSAXParser()).loadString(content) + }.toOption diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/Statements.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/Statements.scala index 4922152e..9174d0ae 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/Statements.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/Statements.scala @@ -4,5 +4,5 @@ import io.shiftleft.codepropertygraph.Cpg import io.shiftleft.semanticcpg.language.* object Statements: - def countAll(cpg: Cpg): Long = - cpg.method.topLevelExpressions.size + def countAll(cpg: Cpg): Long = + cpg.method.topLevelExpressions.size diff --git a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/Torch.scala b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/Torch.scala index aca4c844..a4c6bb66 100644 --- a/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/Torch.scala +++ b/semanticcpg/src/main/scala/io/shiftleft/semanticcpg/utils/Torch.scala @@ -10,7 +10,7 @@ import scala.util.{Failure, Success, Try} import java.nio.file.Path object Torch: - CPythonInterpreter.execManyLines(""" + CPythonInterpreter.execManyLines(""" | |SCIENCE_PACK_AVAILABLE = True |try: @@ -19,91 +19,91 @@ object Torch: | SCIENCE_PACK_AVAILABLE = False |""".stripMargin) - def convert_graphml(gml_file: Path) = - py.Dynamic.global.convert_graphml(gml_file.toAbsolutePath.toString) + def convert_graphml(gml_file: Path) = + py.Dynamic.global.convert_graphml(gml_file.toAbsolutePath.toString) - def to_pyg(gml_file: Path) = py.Dynamic.global.to_pyg(convert_graphml(gml_file)) + def to_pyg(gml_file: Path) = py.Dynamic.global.to_pyg(convert_graphml(gml_file)) - def diff_graph( - first_gml_file: Path, - second_gml_file: Path, - include_common: Boolean = false, - as_dict: Boolean = false - ) = - val first_graph = py.Dynamic.global.convert_graphml(first_gml_file.toAbsolutePath.toString) + def diff_graph( + first_gml_file: Path, + second_gml_file: Path, + include_common: Boolean = false, + as_dict: Boolean = false + ) = + val first_graph = py.Dynamic.global.convert_graphml(first_gml_file.toAbsolutePath.toString) + val second_graph = + py.Dynamic.global.convert_graphml(second_gml_file.toAbsolutePath.toString) + py.Dynamic.global.diff_graph(first_graph, second_graph, include_common, as_dict) + + def is_similar( + first_gml_file: Path, + second_gml_file: Path, + edit_distance: Int = 10, + upper_bound: Int = 500, + timeout: Int = 5 + ): Boolean = + val first_graph = py.Dynamic.global.convert_graphml(first_gml_file.toAbsolutePath.toString) + val second_graph = + py.Dynamic.global.convert_graphml(second_gml_file.toAbsolutePath.toString) + py.Dynamic.global + .is_similar( + first_graph, + second_graph, + edit_distance = edit_distance, + upper_bound = upper_bound, + timeout = timeout + ) + .as[Boolean] + end is_similar + + def is_similar( + first_result: ExportResult, + second_result: ExportResult, + edit_distance: Int + ): Boolean = + if first_result.files.nonEmpty && second_result.files.nonEmpty then + val first_gml_file = first_result.files.head + val second_gml_file = second_result.files.head + val first_graph = + py.Dynamic.global.convert_graphml(first_gml_file.toAbsolutePath.toString) val second_graph = py.Dynamic.global.convert_graphml(second_gml_file.toAbsolutePath.toString) - py.Dynamic.global.diff_graph(first_graph, second_graph, include_common, as_dict) + py.Dynamic.global.is_similar( + first_graph, + second_graph, + edit_distance = edit_distance + ).as[Boolean] + else + false - def is_similar( - first_gml_file: Path, - second_gml_file: Path, - edit_distance: Int = 10, - upper_bound: Int = 500, - timeout: Int = 5 - ): Boolean = - val first_graph = py.Dynamic.global.convert_graphml(first_gml_file.toAbsolutePath.toString) + def edit_distance( + first_result: ExportResult, + second_result: ExportResult, + upper_bound: Int = 500, + timeout: Int = 5 + ): Double = + if first_result.files.nonEmpty && second_result.files.nonEmpty then + val first_gml_file = first_result.files.head + val second_gml_file = second_result.files.head + val first_graph = + py.Dynamic.global.convert_graphml(first_gml_file.toAbsolutePath.toString) val second_graph = py.Dynamic.global.convert_graphml(second_gml_file.toAbsolutePath.toString) - py.Dynamic.global - .is_similar( - first_graph, - second_graph, - edit_distance = edit_distance, - upper_bound = upper_bound, - timeout = timeout - ) - .as[Boolean] - end is_similar - - def is_similar( - first_result: ExportResult, - second_result: ExportResult, - edit_distance: Int - ): Boolean = - if first_result.files.nonEmpty && second_result.files.nonEmpty then - val first_gml_file = first_result.files.head - val second_gml_file = second_result.files.head - val first_graph = - py.Dynamic.global.convert_graphml(first_gml_file.toAbsolutePath.toString) - val second_graph = - py.Dynamic.global.convert_graphml(second_gml_file.toAbsolutePath.toString) - py.Dynamic.global.is_similar( - first_graph, - second_graph, - edit_distance = edit_distance - ).as[Boolean] - else - false - - def edit_distance( - first_result: ExportResult, - second_result: ExportResult, - upper_bound: Int = 500, - timeout: Int = 5 - ): Double = - if first_result.files.nonEmpty && second_result.files.nonEmpty then - val first_gml_file = first_result.files.head - val second_gml_file = second_result.files.head - val first_graph = - py.Dynamic.global.convert_graphml(first_gml_file.toAbsolutePath.toString) - val second_graph = - py.Dynamic.global.convert_graphml(second_gml_file.toAbsolutePath.toString) - py.Dynamic.global.ged( - first_graph, - second_graph, - upper_bound = upper_bound, - timeout = timeout - ).as[Double] - else - -1.0 + py.Dynamic.global.ged( + first_graph, + second_graph, + upper_bound = upper_bound, + timeout = timeout + ).as[Double] + else + -1.0 - def generate_sp_model( - filename: String, - vocab_size: Int = 20000, - model_type: String = "unigram", - model_prefix: String = "m_user" - ) = py.Dynamic.global.generate_sp_model(filename, vocab_size, model_type, model_prefix) + def generate_sp_model( + filename: String, + vocab_size: Int = 20000, + model_type: String = "unigram", + model_prefix: String = "m_user" + ) = py.Dynamic.global.generate_sp_model(filename, vocab_size, model_type, model_prefix) - def load_sp_model(filename: String) = py.Dynamic.global.load_sp_model(filename) + def load_sp_model(filename: String) = py.Dynamic.global.load_sp_model(filename) end Torch